未验证 提交 8c23840c 编写于 作者: A Artem Zuikov 提交者: GitHub

ExpressionActions refactoring: extract ArrayJoinAction class (#8998)

refactoring: extract ArrayJoinAction class
上级 3c1735d9
#include <Common/typeid_cast.h>
#include <Columns/ColumnArray.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Interpreters/Context.h>
#include <Interpreters/ArrayJoinAction.h>
namespace DB
{
/// Error codes thrown from this translation unit.
/// LOGICAL_ERROR is used by the "No arrays to join" checks below, so it is declared here explicitly
/// instead of depending on some other header happening to declare it.
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
    extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
    extern const int TYPE_MISMATCH;
}
/// Creates an ARRAY JOIN action over `array_joined_columns_`.
///
/// `array_join_is_left` selects LEFT ARRAY JOIN; whether the join is "unaligned" is read from the
/// `enable_unaligned_array_join` setting of `context`.
/// The helper functions required by the chosen mode are looked up in FunctionFactory right away:
///  - unaligned mode: length, greatest, arrayResize;
///  - aligned LEFT mode: emptyArrayToSingle;
///  - plain aligned mode: nothing.
/// Throws LOGICAL_ERROR if the set of columns is empty.
ArrayJoinAction::ArrayJoinAction(const NameSet & array_joined_columns_, bool array_join_is_left, const Context & context)
    : columns(array_joined_columns_)
    , is_left(array_join_is_left)
    , is_unaligned(context.getSettingsRef().enable_unaligned_array_join)
{
    if (columns.empty())
        throw Exception("No arrays to join", ErrorCodes::LOGICAL_ERROR);

    /// Plain aligned ARRAY JOIN does not need any helper function.
    if (!is_unaligned && !is_left)
        return;

    auto & factory = FunctionFactory::instance();

    if (is_unaligned)
    {
        function_length = factory.get("length", context);
        function_greatest = factory.get("greatest", context);
        function_arrayResize = factory.get("arrayResize", context);
        return;
    }

    /// Aligned LEFT ARRAY JOIN.
    function_builder = factory.get("emptyArrayToSingle", context);
}
/// Updates the header (`sample_block`) to what it will look like after ARRAY JOIN:
/// every joined column gets the nested type of its array type and its column pointer is reset.
/// Throws TYPE_MISMATCH if a joined column is not of array type.
void ArrayJoinAction::prepare(Block & sample_block)
{
    for (const auto & column_name : columns)
    {
        ColumnWithTypeAndName & elem = sample_block.getByName(column_name);

        const auto * as_array = typeid_cast<const DataTypeArray *>(elem.type.get());
        if (as_array == nullptr)
            throw Exception("ARRAY JOIN requires array argument", ErrorCodes::TYPE_MISMATCH);

        /// The joined column holds array elements, not arrays.
        elem.type = as_array->getNestedType();
        elem.column = nullptr;
    }
}
/// Performs ARRAY JOIN on `block` in place.
///
/// Every column listed in `columns` is replaced by the data column of its array and its type by the
/// nested type; every other column is replicated according to the offsets of the joined arrays.
///
/// Modes:
///  - unaligned: the joined arrays are first resized (arrayResize) to the per-row maximum length
///    (at least 1 for LEFT ARRAY JOIN), so arrays of different sizes are accepted;
///  - aligned LEFT: empty arrays are replaced by single-element arrays (emptyArrayToSingle);
///  - aligned plain: arrays are used as they are.
/// In aligned modes all joined arrays must have equal offsets.
///
/// `dry_run` is forwarded only to the emptyArrayToSingle call of the aligned LEFT mode.
///
/// Throws LOGICAL_ERROR if there are no columns to join, TYPE_MISMATCH if a joined column is not an array,
/// SIZES_OF_ARRAYS_DOESNT_MATCH if the arrays differ in sizes in an aligned mode.
void ArrayJoinAction::execute(Block & block, bool dry_run)
{
    if (columns.empty())
        throw Exception("No arrays to join", ErrorCodes::LOGICAL_ERROR);

    /// One of the joined arrays; its offsets drive the replication of the non-joined columns.
    ColumnPtr any_array_ptr = block.getByName(*columns.begin()).column->convertToFullColumnIfConst();
    const ColumnArray * any_array = typeid_cast<const ColumnArray *>(&*any_array_ptr);
    if (!any_array)
        throw Exception("ARRAY JOIN of not array: " + *columns.begin(), ErrorCodes::TYPE_MISMATCH);

    /// If LEFT ARRAY JOIN, then we create columns in which empty arrays are replaced by arrays with one element - the default value.
    std::map<String, ColumnPtr> non_empty_array_columns;

    if (is_unaligned)
    {
        /// Resize all array joined columns to the longest one, (at least 1 if LEFT ARRAY JOIN), padded with default values.
        auto rows = block.rows();
        auto uint64 = std::make_shared<DataTypeUInt64>();

        /// Running per-row maximum of array lengths; starts from the minimal allowed length.
        ColumnWithTypeAndName column_of_max_length;
        if (is_left)
            column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 1u), uint64, {});
        else
            column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 0u), uint64, {});

        for (const auto & name : columns)
        {
            auto & src_col = block.getByName(name);

            /// lengths = length(src_col)
            Block tmp_block{src_col, {{}, uint64, {}}};
            function_length->build({src_col})->execute(tmp_block, {0}, 1, rows);

            /// column_of_max_length = greatest(column_of_max_length, lengths)
            Block tmp_block2{
                column_of_max_length, tmp_block.safeGetByPosition(1), {{}, uint64, {}}};
            function_greatest->build({column_of_max_length, tmp_block.safeGetByPosition(1)})->execute(tmp_block2, {0, 1}, 2, rows);
            column_of_max_length = tmp_block2.safeGetByPosition(2);
        }

        for (const auto & name : columns)
        {
            auto & src_col = block.getByName(name);

            /// src_col = arrayResize(src_col, column_of_max_length); the block itself is updated.
            Block tmp_block{src_col, column_of_max_length, {{}, src_col.type, {}}};
            function_arrayResize->build({src_col, column_of_max_length})->execute(tmp_block, {0, 1}, 2, rows);
            src_col.column = tmp_block.safeGetByPosition(2).column;
            any_array_ptr = src_col.column->convertToFullColumnIfConst();
        }

        /// Reference cast (as in the LEFT branch): fails with an exception instead of leaving a null pointer
        /// that would be dereferenced while replicating the other columns.
        any_array = &typeid_cast<const ColumnArray &>(*any_array_ptr);
    }
    else if (is_left)
    {
        for (const auto & name : columns)
        {
            /// The block element is only read here, so bind a reference instead of copying it.
            const auto & src_col = block.getByName(name);

            Block tmp_block{src_col, {{}, src_col.type, {}}};
            function_builder->build({src_col})->execute(tmp_block, {0}, 1, src_col.column->size(), dry_run);
            non_empty_array_columns[name] = tmp_block.safeGetByPosition(1).column;
        }

        any_array_ptr = non_empty_array_columns.begin()->second->convertToFullColumnIfConst();
        any_array = &typeid_cast<const ColumnArray &>(*any_array_ptr);
    }

    const size_t num_columns = block.columns();
    for (size_t i = 0; i < num_columns; ++i)
    {
        ColumnWithTypeAndName & current = block.safeGetByPosition(i);

        if (!columns.count(current.name))
        {
            /// Not a joined column: repeat each row as many times as there are array elements in it.
            current.column = current.column->replicate(any_array->getOffsets());
            continue;
        }

        const auto * array_type = typeid_cast<const DataTypeArray *>(&*current.type);
        if (!array_type)
            throw Exception("ARRAY JOIN of not array: " + current.name, ErrorCodes::TYPE_MISMATCH);

        /// In aligned LEFT mode the source is the column produced by emptyArrayToSingle, not the block's one.
        ColumnPtr array_ptr = (is_left && !is_unaligned) ? non_empty_array_columns[current.name] : current.column;
        array_ptr = array_ptr->convertToFullColumnIfConst();

        const ColumnArray & array = typeid_cast<const ColumnArray &>(*array_ptr);

        /// In aligned modes `any_array` always points into `any_array_ptr` and is not null here.
        if (!is_unaligned && !array.hasEqualOffsets(*any_array))
            throw Exception("Sizes of ARRAY-JOIN-ed arrays do not match", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);

        current.column = array.getDataPtr();
        current.type = array_type->getNestedType();
    }
}
/// Drops joined columns whose results are not needed and registers the remaining ones.
///
/// A column absent from `needed_columns` is removed from `columns`, unless it is the last one left:
/// at least one column is always kept so that ARRAY JOIN still produces the right number of rows.
/// Every kept column is added to `needed_columns` and removed from `unmodified_columns`;
/// a column kept only to preserve the number of rows is also added to `final_columns`.
void ArrayJoinAction::finalize(NameSet & needed_columns, NameSet & unmodified_columns, NameSet & final_columns)
{
    auto it = columns.begin();
    while (it != columns.end())
    {
        const bool is_needed = needed_columns.count(*it) != 0;

        /// Unused column and not the only one left: stop joining it.
        if (!is_needed && columns.size() > 1)
        {
            it = columns.erase(it);
            continue;
        }

        needed_columns.insert(*it);
        unmodified_columns.erase(*it);

        /// No ARRAY JOIN result is used at all: force this column into the output
        /// so that the number of rows is not lost.
        if (!is_needed)
            final_columns.insert(*it);

        ++it;
    }
}
}
#pragma once
#include <Core/Names.h>
#include <Core/Block.h>
namespace DB
{
/// Forward declarations: only references and shared pointers to these types are used in this header.
class Context;
class IFunctionOverloadResolver;
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;
/// State and logic of a single [LEFT] ARRAY JOIN step over a set of array columns.
struct ArrayJoinAction
{
/// Names of the columns to be array-joined. finalize() may remove names from this set.
NameSet columns;
/// True for LEFT ARRAY JOIN.
bool is_left = false;
/// True when the `enable_unaligned_array_join` setting was on at construction time.
bool is_unaligned = false;
/// For unaligned [LEFT] ARRAY JOIN: resolvers of `length`, `greatest` and `arrayResize`.
/// Set by the constructor only when is_unaligned is true.
FunctionOverloadResolverPtr function_length;
FunctionOverloadResolverPtr function_greatest;
FunctionOverloadResolverPtr function_arrayResize;
/// For LEFT ARRAY JOIN: resolver of `emptyArrayToSingle`.
/// Set by the constructor only when is_left is true and is_unaligned is false.
FunctionOverloadResolverPtr function_builder;
/// Throws LOGICAL_ERROR if `array_joined_columns_` is empty.
ArrayJoinAction(const NameSet & array_joined_columns_, bool array_join_is_left, const Context & context);
/// Replaces the type of each joined column in `sample_block` with the nested type of its array
/// and resets its column pointer. Throws TYPE_MISMATCH for a non-array column.
void prepare(Block & sample_block);
/// Performs ARRAY JOIN on `block` in place: joined columns are replaced by the data of their arrays,
/// all other columns are replicated according to the array offsets.
void execute(Block & block, bool dry_run);
/// Removes joined columns that are not in `needed_columns` (always keeping at least one)
/// and updates the three name sets for the columns that are kept.
void finalize(NameSet & needed_columns, NameSet & unmodified_columns, NameSet & final_columns);
};
}
......@@ -6,13 +6,11 @@
#include <Interpreters/ExpressionJIT.h>
#include <Interpreters/AnalyzedJoin.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnArray.h>
#include <Common/typeid_cast.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/IFunction.h>
#include <set>
#include <optional>
#include <Columns/ColumnSet.h>
#include <Functions/FunctionHelpers.h>
......@@ -33,20 +31,20 @@ namespace ErrorCodes
extern const int UNKNOWN_IDENTIFIER;
extern const int UNKNOWN_ACTION;
extern const int NOT_FOUND_COLUMN_IN_BLOCK;
extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
extern const int TOO_MANY_TEMPORARY_COLUMNS;
extern const int TOO_MANY_TEMPORARY_NON_CONST_COLUMNS;
extern const int TYPE_MISMATCH;
}
/// Read comment near usage
static constexpr auto DUMMY_COLUMN_NAME = "_dummy";
Names ExpressionAction::getNeededColumns() const
{
Names res = argument_names;
res.insert(res.end(), array_joined_columns.begin(), array_joined_columns.end());
if (array_join)
res.insert(res.end(), array_join->columns.begin(), array_join->columns.end());
if (table_join)
res.insert(res.end(), table_join->keyNamesLeft().begin(), table_join->keyNamesLeft().end());
......@@ -143,23 +141,9 @@ ExpressionAction ExpressionAction::addAliases(const NamesWithAliases & aliased_c
ExpressionAction ExpressionAction::arrayJoin(const NameSet & array_joined_columns, bool array_join_is_left, const Context & context)
{
if (array_joined_columns.empty())
throw Exception("No arrays to join", ErrorCodes::LOGICAL_ERROR);
ExpressionAction a;
a.type = ARRAY_JOIN;
a.array_joined_columns = array_joined_columns;
a.array_join_is_left = array_join_is_left;
a.unaligned_array_join = context.getSettingsRef().enable_unaligned_array_join;
if (a.unaligned_array_join)
{
a.function_length = FunctionFactory::instance().get("length", context);
a.function_greatest = FunctionFactory::instance().get("greatest", context);
a.function_arrayResize = FunctionFactory::instance().get("arrayResize", context);
}
else if (array_join_is_left)
a.function_builder = FunctionFactory::instance().get("emptyArrayToSingle", context);
a.array_join = std::make_shared<ArrayJoinAction>(array_joined_columns, array_join_is_left, context);
return a;
}
......@@ -172,7 +156,6 @@ ExpressionAction ExpressionAction::ordinaryJoin(std::shared_ptr<AnalyzedJoin> ta
return a;
}
void ExpressionAction::prepare(Block & sample_block, const Settings & settings, NameSet & names_not_for_constant_folding)
{
// std::cerr << "preparing: " << toString() << std::endl;
......@@ -256,16 +239,7 @@ void ExpressionAction::prepare(Block & sample_block, const Settings & settings,
case ARRAY_JOIN:
{
for (const auto & name : array_joined_columns)
{
ColumnWithTypeAndName & current = sample_block.getByName(name);
const DataTypeArray * array_type = typeid_cast<const DataTypeArray *>(&*current.type);
if (!array_type)
throw Exception("ARRAY JOIN requires array argument", ErrorCodes::TYPE_MISMATCH);
current.type = array_type->getNestedType();
current.column = nullptr;
}
array_join->prepare(sample_block);
break;
}
......@@ -383,95 +357,7 @@ void ExpressionAction::execute(Block & block, bool dry_run, ExtraBlockPtr & not_
case ARRAY_JOIN:
{
if (array_joined_columns.empty())
throw Exception("No arrays to join", ErrorCodes::LOGICAL_ERROR);
ColumnPtr any_array_ptr = block.getByName(*array_joined_columns.begin()).column->convertToFullColumnIfConst();
const ColumnArray * any_array = typeid_cast<const ColumnArray *>(&*any_array_ptr);
if (!any_array)
throw Exception("ARRAY JOIN of not array: " + *array_joined_columns.begin(), ErrorCodes::TYPE_MISMATCH);
/// If LEFT ARRAY JOIN, then we create columns in which empty arrays are replaced by arrays with one element - the default value.
std::map<String, ColumnPtr> non_empty_array_columns;
if (unaligned_array_join)
{
/// Resize all array joined columns to the longest one, (at least 1 if LEFT ARRAY JOIN), padded with default values.
auto rows = block.rows();
auto uint64 = std::make_shared<DataTypeUInt64>();
ColumnWithTypeAndName column_of_max_length;
if (array_join_is_left)
column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 1u), uint64, {});
else
column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 0u), uint64, {});
for (const auto & name : array_joined_columns)
{
auto & src_col = block.getByName(name);
Block tmp_block{src_col, {{}, uint64, {}}};
function_length->build({src_col})->execute(tmp_block, {0}, 1, rows);
Block tmp_block2{
column_of_max_length, tmp_block.safeGetByPosition(1), {{}, uint64, {}}};
function_greatest->build({column_of_max_length, tmp_block.safeGetByPosition(1)})->execute(tmp_block2, {0, 1}, 2, rows);
column_of_max_length = tmp_block2.safeGetByPosition(2);
}
for (const auto & name : array_joined_columns)
{
auto & src_col = block.getByName(name);
Block tmp_block{src_col, column_of_max_length, {{}, src_col.type, {}}};
function_arrayResize->build({src_col, column_of_max_length})->execute(tmp_block, {0, 1}, 2, rows);
src_col.column = tmp_block.safeGetByPosition(2).column;
any_array_ptr = src_col.column->convertToFullColumnIfConst();
}
any_array = typeid_cast<const ColumnArray *>(&*any_array_ptr);
}
else if (array_join_is_left)
{
for (const auto & name : array_joined_columns)
{
auto src_col = block.getByName(name);
Block tmp_block{src_col, {{}, src_col.type, {}}};
function_builder->build({src_col})->execute(tmp_block, {0}, 1, src_col.column->size(), dry_run);
non_empty_array_columns[name] = tmp_block.safeGetByPosition(1).column;
}
any_array_ptr = non_empty_array_columns.begin()->second->convertToFullColumnIfConst();
any_array = &typeid_cast<const ColumnArray &>(*any_array_ptr);
}
size_t columns = block.columns();
for (size_t i = 0; i < columns; ++i)
{
ColumnWithTypeAndName & current = block.safeGetByPosition(i);
if (array_joined_columns.count(current.name))
{
if (!typeid_cast<const DataTypeArray *>(&*current.type))
throw Exception("ARRAY JOIN of not array: " + current.name, ErrorCodes::TYPE_MISMATCH);
ColumnPtr array_ptr = (array_join_is_left && !unaligned_array_join) ? non_empty_array_columns[current.name] : current.column;
array_ptr = array_ptr->convertToFullColumnIfConst();
const ColumnArray & array = typeid_cast<const ColumnArray &>(*array_ptr);
if (!unaligned_array_join && !array.hasEqualOffsets(typeid_cast<const ColumnArray &>(*any_array_ptr)))
throw Exception("Sizes of ARRAY-JOIN-ed arrays do not match", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
current.column = typeid_cast<const ColumnArray &>(*array_ptr).getDataPtr();
current.type = typeid_cast<const DataTypeArray &>(*current.type).getNestedType();
}
else
{
current.column = current.column->replicate(any_array->getOffsets());
}
}
array_join->execute(block, dry_run);
break;
}
......@@ -539,7 +425,6 @@ void ExpressionAction::execute(Block & block, bool dry_run, ExtraBlockPtr & not_
}
}
void ExpressionAction::executeOnTotals(Block & block) const
{
if (type != JOIN)
......@@ -584,10 +469,10 @@ std::string ExpressionAction::toString() const
break;
case ARRAY_JOIN:
ss << (array_join_is_left ? "LEFT " : "") << "ARRAY JOIN ";
for (NameSet::const_iterator it = array_joined_columns.begin(); it != array_joined_columns.end(); ++it)
ss << (array_join->is_left ? "LEFT " : "") << "ARRAY JOIN ";
for (NameSet::const_iterator it = array_join->columns.begin(); it != array_join->columns.end(); ++it)
{
if (it != array_joined_columns.begin())
if (it != array_join->columns.begin())
ss << ", ";
ss << *it;
}
......@@ -675,7 +560,9 @@ void ExpressionActions::addImpl(ExpressionAction action, Names & new_names)
{
if (action.result_name != "")
new_names.push_back(action.result_name);
new_names.insert(new_names.end(), action.array_joined_columns.begin(), action.array_joined_columns.end());
if (action.array_join)
new_names.insert(new_names.end(), action.array_join->columns.begin(), action.array_join->columns.end());
/// Compiled functions are custom functions and they don't need building
if (action.type == ExpressionAction::APPLY_FUNCTION && !action.is_function_compiled)
......@@ -713,7 +600,7 @@ void ExpressionActions::prependArrayJoin(const ExpressionAction & action, const
if (action.type != ExpressionAction::ARRAY_JOIN)
throw Exception("ARRAY_JOIN action expected", ErrorCodes::LOGICAL_ERROR);
NameSet array_join_set(action.array_joined_columns.begin(), action.array_joined_columns.end());
NameSet array_join_set(action.array_join->columns.begin(), action.array_join->columns.end());
for (auto & it : input_columns)
{
if (array_join_set.count(it.name))
......@@ -738,12 +625,12 @@ bool ExpressionActions::popUnusedArrayJoin(const Names & required_columns, Expre
if (actions.empty() || actions.back().type != ExpressionAction::ARRAY_JOIN)
return false;
NameSet required_set(required_columns.begin(), required_columns.end());
for (const std::string & name : actions.back().array_joined_columns)
for (const std::string & name : actions.back().array_join->columns)
{
if (required_set.count(name))
return false;
}
for (const std::string & name : actions.back().array_joined_columns)
for (const std::string & name : actions.back().array_join->columns)
{
DataTypePtr & type = sample_block.getByName(name).type;
type = std::make_shared<DataTypeArray>(type);
......@@ -884,29 +771,7 @@ void ExpressionActions::finalize(const Names & output_columns)
}
else if (action.type == ExpressionAction::ARRAY_JOIN)
{
/// Do not ARRAY JOIN columns that are not used anymore.
/// Usually, such columns are not used until ARRAY JOIN, and therefore are ejected further in this function.
/// We will not remove all the columns so as not to lose the number of rows.
for (auto it = action.array_joined_columns.begin(); it != action.array_joined_columns.end();)
{
bool need = needed_columns.count(*it);
if (!need && action.array_joined_columns.size() > 1)
{
action.array_joined_columns.erase(it++);
}
else
{
needed_columns.insert(*it);
unmodified_columns.erase(*it);
/// If no ARRAY JOIN results are used, forcibly leave an arbitrary column at the output,
/// so you do not lose the number of rows.
if (!need)
final_columns.insert(*it);
++it;
}
}
action.array_join->finalize(needed_columns, unmodified_columns, final_columns);
}
else
{
......@@ -1143,7 +1008,8 @@ void ExpressionActions::optimizeArrayJoin()
if (actions[i].result_name != "")
array_joined_columns.insert(actions[i].result_name);
array_joined_columns.insert(actions[i].array_joined_columns.begin(), actions[i].array_joined_columns.end());
if (actions[i].array_join)
array_joined_columns.insert(actions[i].array_join->columns.begin(), actions[i].array_join->columns.end());
array_join_dependencies.insert(needed.begin(), needed.end());
}
......@@ -1274,8 +1140,8 @@ UInt128 ExpressionAction::ActionHash::operator()(const ExpressionAction & action
hash.update(arg_name);
break;
case ARRAY_JOIN:
hash.update(action.array_join_is_left);
for (const auto & col : action.array_joined_columns)
hash.update(action.array_join->is_left);
for (const auto & col : action.array_join->columns)
hash.update(col);
break;
case JOIN:
......@@ -1332,11 +1198,15 @@ bool ExpressionAction::operator==(const ExpressionAction & other) const
return false;
}
bool same_array_join = !array_join && !other.array_join;
if (array_join && other.array_join)
same_array_join = (array_join->columns == other.array_join->columns) &&
(array_join->is_left == other.array_join->is_left);
return source_name == other.source_name
&& result_name == other.result_name
&& argument_names == other.argument_names
&& array_joined_columns == other.array_joined_columns
&& array_join_is_left == other.array_join_is_left
&& same_array_join
&& AnalyzedJoin::sameJoin(table_join.get(), other.table_join.get())
&& projection == other.projection
&& is_function_compiled == other.is_function_compiled;
......
......@@ -11,6 +11,7 @@
#include <unordered_map>
#include <unordered_set>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Interpreters/ArrayJoinAction.h>
namespace DB
......@@ -81,15 +82,10 @@ public:
/// For ADD_COLUMN.
ColumnPtr added_column;
/// For APPLY_FUNCTION and LEFT ARRAY JOIN.
/// For APPLY_FUNCTION.
/// OverloadResolver is used before action was added to ExpressionActions (when we don't know types of arguments).
FunctionOverloadResolverPtr function_builder;
/// For unaligned [LEFT] ARRAY JOIN
FunctionOverloadResolverPtr function_length;
FunctionOverloadResolverPtr function_greatest;
FunctionOverloadResolverPtr function_arrayResize;
/// Can be used after action was added to ExpressionActions if we want to get function signature or properties like monotonicity.
FunctionBasePtr function_base;
/// Prepared function which is used in function execution.
......@@ -97,10 +93,8 @@ public:
Names argument_names;
bool is_function_compiled = false;
/// For ARRAY_JOIN
NameSet array_joined_columns;
bool array_join_is_left = false;
bool unaligned_array_join = false;
/// For ARRAY JOIN
std::shared_ptr<ArrayJoinAction> array_join;
/// For JOIN
std::shared_ptr<const AnalyzedJoin> table_join;
......
......@@ -322,7 +322,7 @@ ColumnsDescription InterpreterCreateQuery::getColumnsDescription(const ASTExpres
{
auto syntax_analyzer_result = SyntaxAnalyzer(context).analyze(default_expr_list, column_names_and_types);
const auto actions = ExpressionAnalyzer(default_expr_list, syntax_analyzer_result, context).getActions(true);
for (auto action : actions->getActions())
for (auto & action : actions->getActions())
if (action.type == ExpressionAction::Type::JOIN || action.type == ExpressionAction::Type::ARRAY_JOIN)
throw Exception("Cannot CREATE table. Unsupported default value that requires ARRAY JOIN or JOIN action", ErrorCodes::THERE_IS_NO_DEFAULT_VALUE);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册