未验证 提交 7af621f6 编写于 作者: A alexey-milovidov 提交者: GitHub

Merge pull request #3728 from amosbird/join

Versatile StorageJoin
#include <Functions/FunctionJoinGet.h>
#include <Functions/FunctionHelpers.h>
#include <Interpreters/Context.h>
#include <Interpreters/Join.h>
#include <Storages/StorageJoin.h>
#include <Functions/FunctionFactory.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
/// Builds the executable `joinGet` function for one call site.
///
/// Expects exactly three arguments:
///   0 - constant String: name of a table with the Join engine;
///   1 - constant String: name of the column to fetch from that table;
///   2 - the key to look up.
/// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH for a wrong argument count and
/// ILLEGAL_TYPE_OF_ARGUMENT when a name argument is not a constant string or
/// the table is not a StorageJoin.
FunctionBasePtr FunctionBuilderJoinGet::buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const
{
    if (arguments.size() != 3)
        throw Exception{"Function " + getName() + " takes 3 arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};

    /// Extracts the value of a constant String argument; `ordinal` is only used in the error message.
    const auto get_const_string = [&](size_t index, const char * ordinal) -> String
    {
        const auto & argument = arguments[index];
        const auto * const_column = checkAndGetColumnConst<ColumnString>(argument.column.get());
        if (!const_column)
            throw Exception{"Illegal type " + argument.type->getName() + " of " + ordinal + " argument of function " + getName()
                    + ", expected a const string.",
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
        return const_column->getValue<String>();
    };

    const String join_name = get_const_string(0, "first");

    /// The table is looked up with an empty database name.
    const auto table = context.getTable("", join_name);
    auto * storage_join = dynamic_cast<StorageJoin *>(table.get());
    if (storage_join == nullptr)
        throw Exception{"Table " + join_name + " should have engine StorageJoin", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
    auto join = storage_join->getJoin();

    const String attr_name = get_const_string(1, "second");

    DataTypes data_types;
    data_types.reserve(arguments.size());
    for (const auto & argument : arguments)
        data_types.push_back(argument.type);

    auto function = std::make_shared<FunctionJoinGet>(join, attr_name);
    return std::make_shared<DefaultFunction>(std::move(function), data_types, join->joinGetReturnType(attr_name));
}
/// Looks up every key of the third argument in the Join and writes the fetched attribute to `result`.
///
/// block     - block holding the argument columns and the result slot;
/// arguments - positions of the arguments; only arguments[2] (the key) is read here;
/// result    - position in `block` that receives the joined column.
void FunctionJoinGet::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/)
{
    /// Take a copy of the key column entry: renaming it through a reference would erase
    /// the name of the argument column inside the caller's block.
    auto ctn = block.getByPosition(arguments[2]);
    ctn.name = ""; // make sure the key name never collide with the join columns
    Block key_block = {ctn};
    join->joinGet(key_block, attr_name);
    /// key_block now holds the key at position 0 and the fetched attribute at position 1.
    block.getByPosition(result) = key_block.getByPosition(1);
}
/// Registers the builder of the `joinGet` function in the given function factory.
void registerFunctionJoinGet(FunctionFactory & factory)
{
factory.registerFunction<FunctionBuilderJoinGet>();
}
}
#include <Functions/IFunction.h>
namespace DB
{
class Context;
class Join;
using JoinPtr = std::shared_ptr<Join>;
class FunctionJoinGet final : public IFunction, public std::enable_shared_from_this<FunctionJoinGet>
{
public:
static constexpr auto name = "joinGet";
FunctionJoinGet(JoinPtr join, const String & attr_name) : join(std::move(join)), attr_name(attr_name) {}
String getName() const override { return name; }
protected:
DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override { return nullptr; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
private:
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
private:
JoinPtr join;
const String attr_name;
};
/// Builder of the `joinGet` function.
/// Keeps a reference to the query context, which buildImpl uses to find the Join table.
class FunctionBuilderJoinGet final : public FunctionBuilderImpl
{
public:
    static constexpr auto name = "joinGet";

    FunctionBuilderJoinGet(const Context & context_) : context(context_) {}

    static FunctionBuilderPtr create(const Context & context)
    {
        return std::make_shared<FunctionBuilderJoinGet>(context);
    }

    String getName() const override { return name; }

protected:
    FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const override;

    /// Yields a null pointer; the return type is passed to the built function by buildImpl.
    DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override { return DataTypePtr{}; }

private:
    /// Declared variadic with zero fixed arguments.
    bool isVariadic() const override { return true; }
    size_t getNumberOfArguments() const override { return 0; }

    const Context & context;
};
}
......@@ -40,6 +40,7 @@ void registerFunctionToLowCardinality(FunctionFactory &);
void registerFunctionLowCardinalityIndices(FunctionFactory &);
void registerFunctionLowCardinalityKeys(FunctionFactory &);
void registerFunctionsIn(FunctionFactory &);
void registerFunctionJoinGet(FunctionFactory &);
void registerFunctionsMiscellaneous(FunctionFactory & factory)
{
......@@ -80,6 +81,7 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory)
registerFunctionLowCardinalityIndices(factory);
registerFunctionLowCardinalityKeys(factory);
registerFunctionsIn(factory);
registerFunctionJoinGet(factory);
}
}
......@@ -150,15 +150,18 @@ ExpressionAction ExpressionAction::arrayJoin(const NameSet & array_joined_column
return a;
}
ExpressionAction ExpressionAction::ordinaryJoin(std::shared_ptr<const Join> join_,
const Names & join_key_names_left,
const NamesAndTypesList & columns_added_by_join_)
ExpressionAction ExpressionAction::ordinaryJoin(
std::shared_ptr<const Join> join_,
const Names & join_key_names_left,
const NamesAndTypesList & columns_added_by_join_,
const NameSet & columns_added_by_join_from_right_keys_)
{
ExpressionAction a;
a.type = JOIN;
a.join = std::move(join_);
a.join_key_names_left = join_key_names_left;
a.columns_added_by_join = columns_added_by_join_;
a.columns_added_by_join_from_right_keys = columns_added_by_join_from_right_keys_;
return a;
}
......@@ -427,7 +430,7 @@ void ExpressionAction::execute(Block & block, bool dry_run) const
case JOIN:
{
join->joinBlock(block);
join->joinBlock(block, join_key_names_left, columns_added_by_join_from_right_keys);
break;
}
......@@ -1085,7 +1088,7 @@ BlockInputStreamPtr ExpressionActions::createStreamWithNonJoinedDataIfFullOrRigh
{
for (const auto & action : actions)
if (action.join && (action.join->getKind() == ASTTableJoin::Kind::Full || action.join->getKind() == ASTTableJoin::Kind::Right))
return action.join->createStreamWithNonJoinedRows(source_header, max_block_size);
return action.join->createStreamWithNonJoinedRows(source_header, action.join_key_names_left, max_block_size);
return {};
}
......
......@@ -102,6 +102,7 @@ public:
std::shared_ptr<const Join> join;
Names join_key_names_left;
NamesAndTypesList columns_added_by_join;
NameSet columns_added_by_join_from_right_keys;
/// For PROJECT.
NamesWithAliases projection;
......@@ -118,7 +119,7 @@ public:
static ExpressionAction addAliases(const NamesWithAliases & aliased_columns_);
static ExpressionAction arrayJoin(const NameSet & array_joined_columns, bool array_join_is_left, const Context & context);
static ExpressionAction ordinaryJoin(std::shared_ptr<const Join> join_, const Names & join_key_names_left,
const NamesAndTypesList & columns_added_by_join_);
const NamesAndTypesList & columns_added_by_join_, const NameSet & columns_added_by_join_from_right_keys_);
/// Which columns necessary to perform this action.
Names getNeededColumns() const;
......
......@@ -552,12 +552,13 @@ void ExpressionAnalyzer::addJoinAction(ExpressionActionsPtr & actions, bool only
columns_added_by_join_list.push_back(joined_column.name_and_type);
if (only_types)
actions->add(ExpressionAction::ordinaryJoin(nullptr, analyzedJoin().key_names_left, columns_added_by_join_list));
actions->add(ExpressionAction::ordinaryJoin(nullptr, analyzedJoin().key_names_left,
columns_added_by_join_list, columns_added_by_join_from_right_keys));
else
for (auto & subquery_for_set : subqueries_for_sets)
if (subquery_for_set.second.join)
actions->add(ExpressionAction::ordinaryJoin(subquery_for_set.second.join, analyzedJoin().key_names_left,
columns_added_by_join_list));
columns_added_by_join_list, columns_added_by_join_from_right_keys));
}
bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_types)
......@@ -617,10 +618,8 @@ bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_ty
if (!subquery_for_set.join)
{
JoinPtr join = std::make_shared<Join>(
analyzedJoin().key_names_left, analyzedJoin().key_names_right, columns_added_by_join_from_right_keys,
settings.join_use_nulls, settings.size_limits_for_join,
join_params.kind, join_params.strictness);
JoinPtr join = std::make_shared<Join>(analyzedJoin().key_names_right, settings.join_use_nulls,
settings.size_limits_for_join, join_params.kind, join_params.strictness);
/** For GLOBAL JOINs (in the case, for example, of the push method for executing GLOBAL subqueries), the following occurs
* - in the addExternalStorage function, the JOIN (SELECT ...) subquery is replaced with JOIN _data1,
......
......@@ -31,12 +31,10 @@ namespace ErrorCodes
}
Join::Join(const Names & key_names_left_, const Names & key_names_right_, const NameSet & needed_key_names_right_,
bool use_nulls_, const SizeLimits & limits, ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_)
Join::Join(const Names & key_names_right_, bool use_nulls_, const SizeLimits & limits,
ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_)
: kind(kind_), strictness(strictness_),
key_names_left(key_names_left_),
key_names_right(key_names_right_),
needed_key_names_right(needed_key_names_right_),
use_nulls(use_nulls_),
log(&Logger::get("Join")),
limits(limits)
......@@ -670,7 +668,12 @@ namespace
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename Maps>
void Join::joinBlockImpl(Block & block, const Maps & maps) const
void Join::joinBlockImpl(
Block & block,
const Names & key_names_left,
const NameSet & needed_key_names_right,
const Block & block_with_columns_to_add,
const Maps & maps) const
{
size_t keys_size = key_names_left.size();
ColumnRawPtrs key_columns(keys_size);
......@@ -744,8 +747,8 @@ void Join::joinBlockImpl(Block & block, const Maps & maps) const
{
const ColumnWithTypeAndName & src_column = sample_block_with_columns_to_add.safeGetByPosition(i);
/// Don't insert column if it's in left block.
if (!block.has(src_column.name))
/// Don't insert column if it's in left block or not explicitly required.
if (!block.has(src_column.name) && block_with_columns_to_add.has(src_column.name))
{
added_columns.push_back(src_column.column->cloneEmpty());
added_columns.back()->reserve(src_column.column->size());
......@@ -756,7 +759,6 @@ void Join::joinBlockImpl(Block & block, const Maps & maps) const
size_t rows = block.rows();
/// Used with ANY INNER JOIN
std::unique_ptr<IColumn::Filter> filter;
bool filter_left_keys = (kind == ASTTableJoin::Kind::Inner || kind == ASTTableJoin::Kind::Right) && strictness == ASTTableJoin::Strictness::Any;
......@@ -885,7 +887,7 @@ void Join::joinBlockImplCross(Block & block) const
}
void Join::checkTypesOfKeys(const Block & block_left, const Block & block_right) const
void Join::checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right) const
{
size_t keys_size = key_names_left.size();
......@@ -905,30 +907,90 @@ void Join::checkTypesOfKeys(const Block & block_left, const Block & block_right)
}
void Join::joinBlock(Block & block) const
/// Verifies that the first column of `block_left` and the first column of `block_right`
/// have the same type once Nullable is stripped from both.
/// Throws TYPE_MISMATCH naming both columns and their types otherwise.
static void checkTypeOfKey(const Block & block_left, const Block & block_right)
{
    const auto & left_key = block_left.safeGetByPosition(0);
    const auto & right_key = block_right.safeGetByPosition(0);

    const auto left_type = removeNullable(left_key.type);
    const auto right_type = removeNullable(right_key.type);

    if (left_type->equals(*right_type))
        return;

    throw Exception("Type mismatch of columns to joinGet by: "
        + left_key.name + " " + left_type->getName() + " at left, "
        + right_key.name + " " + right_type->getName() + " at right",
        ErrorCodes::TYPE_MISMATCH);
}
/// Returns the type of `column_name` as it appears in the sample block of columns to add.
/// Takes a shared lock on the Join for the duration of the lookup.
/// Throws LOGICAL_ERROR if there is no such column.
DataTypePtr Join::joinGetReturnType(const String & column_name) const
{
    std::shared_lock lock(rwlock);

    const bool column_exists = sample_block_with_columns_to_add.has(column_name);
    if (!column_exists)
        throw Exception("StorageJoin doesn't contain column " + column_name, ErrorCodes::LOGICAL_ERROR);

    const auto & column = sample_block_with_columns_to_add.getByName(column_name);
    return column.type;
}
/// Runs an ANY LEFT join of `block` against `maps`, adding only `column_name`.
/// The first column of `block` is used as the single left key; no right keys are requested.
template <typename Maps>
void Join::joinGetImpl(Block & block, const String & column_name, const Maps & maps) const
{
    const Names left_key_names{block.getByPosition(0).name};
    const NameSet no_right_keys;
    const Block requested_columns{sample_block_with_columns_to_add.getByName(column_name)};

    joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(
        block, left_key_names, no_right_keys, requested_columns, maps);
}
// TODO: support composite key
// TODO: return multiple columns as named tuple
// TODO: return array of values when strictness == ASTTableJoin::Strictness::All
/// Looks up the keys in the first column of `block` and appends column `column_name` to it.
/// Takes a shared lock on the Join for the duration of the call.
/// Throws LOGICAL_ERROR unless the Join has exactly one key and is of kind LEFT with strictness ANY;
/// the key type is checked before the kind/strictness.
void Join::joinGet(Block & block, const String & column_name) const
{
    std::shared_lock lock(rwlock);

    if (key_names_right.size() != 1)
        throw Exception("joinGet only supports StorageJoin containing exactly one key", ErrorCodes::LOGICAL_ERROR);

    checkTypeOfKey(block, sample_block_with_keys);

    const bool is_any_left = kind == ASTTableJoin::Kind::Left && strictness == ASTTableJoin::Strictness::Any;
    if (!is_any_left)
        throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::LOGICAL_ERROR);

    joinGetImpl(block, column_name, maps_any);
}
void Join::joinBlock(Block & block, const Names & key_names_left, const NameSet & needed_key_names_right) const
{
// std::cerr << "joinBlock: " << block.dumpStructure() << "\n";
std::shared_lock lock(rwlock);
checkTypesOfKeys(block, sample_block_with_keys);
checkTypesOfKeys(block, key_names_left, sample_block_with_keys);
if (kind == ASTTableJoin::Kind::Left && strictness == ASTTableJoin::Strictness::Any)
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(block, maps_any);
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_any);
else if (kind == ASTTableJoin::Kind::Inner && strictness == ASTTableJoin::Strictness::Any)
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::Any>(block, maps_any);
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::Any>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_any);
else if (kind == ASTTableJoin::Kind::Left && strictness == ASTTableJoin::Strictness::All)
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::All>(block, maps_all);
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::All>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_all);
else if (kind == ASTTableJoin::Kind::Inner && strictness == ASTTableJoin::Strictness::All)
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::All>(block, maps_all);
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::All>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_all);
else if (kind == ASTTableJoin::Kind::Full && strictness == ASTTableJoin::Strictness::Any)
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(block, maps_any_full);
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_any_full);
else if (kind == ASTTableJoin::Kind::Right && strictness == ASTTableJoin::Strictness::Any)
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::Any>(block, maps_any_full);
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::Any>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_any_full);
else if (kind == ASTTableJoin::Kind::Full && strictness == ASTTableJoin::Strictness::All)
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::All>(block, maps_all_full);
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::All>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_all_full);
else if (kind == ASTTableJoin::Kind::Right && strictness == ASTTableJoin::Strictness::All)
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::All>(block, maps_all_full);
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::All>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_all_full);
else if (kind == ASTTableJoin::Kind::Cross)
joinBlockImplCross(block);
else
......@@ -1009,14 +1071,14 @@ struct AdderNonJoined<ASTTableJoin::Strictness::All, Mapped>
class NonJoinedBlockInputStream : public IProfilingBlockInputStream
{
public:
NonJoinedBlockInputStream(const Join & parent_, const Block & left_sample_block, size_t max_block_size_)
NonJoinedBlockInputStream(const Join & parent_, const Block & left_sample_block, const Names & key_names_left, size_t max_block_size_)
: parent(parent_), max_block_size(max_block_size_)
{
/** left_sample_block contains keys and "left" columns.
* result_sample_block - keys, "left" columns, and "right" columns.
*/
size_t num_keys = parent.key_names_left.size();
size_t num_keys = key_names_left.size();
size_t num_columns_left = left_sample_block.columns() - num_keys;
size_t num_columns_right = parent.sample_block_with_columns_to_add.columns();
......@@ -1033,7 +1095,7 @@ public:
column_indices_keys_and_right.reserve(num_keys + num_columns_right);
std::vector<bool> is_key_column_in_left_block(num_keys + num_columns_left, false);
for (const std::string & key : parent.key_names_left)
for (const std::string & key : key_names_left)
{
size_t key_pos = left_sample_block.getPositionByName(key);
is_key_column_in_left_block[key_pos] = true;
......@@ -1183,9 +1245,9 @@ private:
};
BlockInputStreamPtr Join::createStreamWithNonJoinedRows(const Block & left_sample_block, size_t max_block_size) const
BlockInputStreamPtr Join::createStreamWithNonJoinedRows(const Block & left_sample_block, const Names & key_names_left, size_t max_block_size) const
{
return std::make_shared<NonJoinedBlockInputStream>(*this, left_sample_block, max_block_size);
return std::make_shared<NonJoinedBlockInputStream>(*this, left_sample_block, key_names_left, max_block_size);
}
......
......@@ -219,8 +219,8 @@ struct JoinKeyGetterHashed
class Join
{
public:
Join(const Names & key_names_left_, const Names & key_names_right_, const NameSet & needed_key_names_right_,
bool use_nulls_, const SizeLimits & limits, ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_);
Join(const Names & key_names_right_, bool use_nulls_, const SizeLimits & limits,
ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_);
bool empty() { return type == Type::EMPTY; }
......@@ -237,7 +237,13 @@ public:
/** Join data from the map (that was previously built by calls to insertFromBlock) to the block with data from "left" table.
* Could be called from different threads in parallel.
*/
void joinBlock(Block & block) const;
void joinBlock(Block & block, const Names & key_names_left, const NameSet & needed_key_names_right) const;
/// Infer the return type for joinGet function
DataTypePtr joinGetReturnType(const String & column_name) const;
/// Used by joinGet function that turns StorageJoin into a dictionary
void joinGet(Block & block, const String & column_name) const;
/** Keep "totals" (separate part of dataset, see WITH TOTALS) to use later.
*/
......@@ -251,7 +257,7 @@ public:
* Use only after all calls to joinBlock was done.
* left_sample_block is passed without account of 'use_nulls' setting (columns will be converted to Nullable inside).
*/
BlockInputStreamPtr createStreamWithNonJoinedRows(const Block & left_sample_block, size_t max_block_size) const;
BlockInputStreamPtr createStreamWithNonJoinedRows(const Block & left_sample_block, const Names & key_names_left, size_t max_block_size) const;
/// Number of keys in all built JOIN maps.
size_t getTotalRowCount() const;
......@@ -320,6 +326,16 @@ public:
M(keys256) \
M(hashed)
/// Used for reading from StorageJoin and applying joinGet function
#define APPLY_FOR_JOIN_VARIANTS_LIMITED(M) \
M(key8) \
M(key16) \
M(key32) \
M(key64) \
M(key_string) \
M(key_fixed_string)
enum class Type
{
EMPTY,
......@@ -353,16 +369,13 @@ public:
private:
friend class NonJoinedBlockInputStream;
friend class JoinBlockInputStream;
ASTTableJoin::Kind kind;
ASTTableJoin::Strictness strictness;
/// Names of key columns (columns for equi-JOIN) in "left" table (in the order they appear in USING clause).
const Names key_names_left;
/// Names of key columns (columns for equi-JOIN) in "right" table (in the order they appear in USING clause).
const Names key_names_right;
/// Names of key columns in the "right" table which should stay in block after join.
const NameSet needed_key_names_right;
/// Substitute NULLs for non-JOINed rows.
bool use_nulls;
......@@ -408,12 +421,20 @@ private:
void init(Type type_);
/// Throw an exception if blocks have different types of key columns.
void checkTypesOfKeys(const Block & block_left, const Block & block_right) const;
void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right) const;
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename Maps>
void joinBlockImpl(Block & block, const Maps & maps) const;
void joinBlockImpl(
Block & block,
const Names & key_names_left,
const NameSet & needed_key_names_right,
const Block & block_with_columns_to_add,
const Maps & maps) const;
void joinBlockImplCross(Block & block) const;
template <typename Maps>
void joinGetImpl(Block & block, const String & column_name, const Maps & maps) const;
};
using JoinPtr = std::shared_ptr<Join>;
......
......@@ -86,11 +86,11 @@ StoragePtr StorageFactory::get(
name = engine_def.name;
if (storage_def->settings && !endsWith(name, "MergeTree") && name != "Kafka")
if (storage_def->settings && !endsWith(name, "MergeTree") && name != "Kafka" && name != "Join")
{
throw Exception(
"Engine " + name + " doesn't support SETTINGS clause. "
"Currently only the MergeTree family of engines and Kafka engine supports it",
"Currently only the MergeTree family of engines, Kafka engine and Join engine support it",
ErrorCodes::BAD_ARGUMENTS);
}
......
#include <Storages/StorageJoin.h>
#include <Storages/StorageFactory.h>
#include <Interpreters/Join.h>
#include <Parsers/ASTCreateQuery.h>
#include <Parsers/ASTIdentifier.h>
#include <Common/typeid_cast.h>
#include <Core/ColumnNumbers.h>
#include <DataStreams/IProfilingBlockInputStream.h>
#include <DataTypes/NestedUtils.h>
#include <Poco/String.h> /// toLower
#include <Poco/File.h>
......@@ -13,6 +17,7 @@ namespace DB
namespace ErrorCodes
{
extern const int UNKNOWN_SET_DATA_VARIANT;
extern const int NO_SUCH_COLUMN_IN_TABLE;
extern const int INCOMPATIBLE_TYPE_OF_JOIN;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
......@@ -24,18 +29,23 @@ StorageJoin::StorageJoin(
const String & path_,
const String & name_,
const Names & key_names_,
ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_,
bool use_nulls_,
SizeLimits limits_,
ASTTableJoin::Kind kind_,
ASTTableJoin::Strictness strictness_,
const ColumnsDescription & columns_)
: StorageSetOrJoinBase{path_, name_, columns_},
key_names(key_names_), kind(kind_), strictness(strictness_)
: StorageSetOrJoinBase{path_, name_, columns_}
, key_names(key_names_)
, use_nulls(use_nulls_)
, limits(limits_)
, kind(kind_)
, strictness(strictness_)
{
for (const auto & key : key_names)
if (!getColumns().hasPhysical(key))
throw Exception{"Key column (" + key + ") does not exist in table declaration.", ErrorCodes::NO_SUCH_COLUMN_IN_TABLE};
/// NOTE StorageJoin doesn't use join_use_nulls setting.
join = std::make_shared<Join>(key_names, key_names, NameSet(), false /* use_nulls */, SizeLimits(), kind, strictness);
join = std::make_shared<Join>(key_names, use_nulls, limits, kind, strictness);
join->setSampleBlock(getSampleBlock().sortColumns());
restore();
}
......@@ -48,7 +58,7 @@ void StorageJoin::truncate(const ASTPtr &)
Poco::File(path + "tmp/").createDirectories();
increment = 0;
join = std::make_shared<Join>(key_names, key_names, NameSet(), false /* use_nulls */, SizeLimits(), kind, strictness);
join = std::make_shared<Join>(key_names, use_nulls, limits, kind, strictness);
join->setSampleBlock(getSampleBlock().sortColumns());
}
......@@ -119,11 +129,237 @@ void registerStorageJoin(StorageFactory & factory)
key_names.push_back(key->name);
}
auto & settings = args.context.getSettingsRef();
auto join_use_nulls = settings.join_use_nulls;
auto max_rows_in_join = settings.max_rows_in_join;
auto max_bytes_in_join = settings.max_bytes_in_join;
auto join_overflow_mode = settings.join_overflow_mode;
if (args.storage_def && args.storage_def->settings)
{
for (const ASTSetQuery::Change & setting : args.storage_def->settings->changes)
{
if (setting.name == "join_use_nulls") join_use_nulls.set(setting.value);
else if (setting.name == "max_rows_in_join") max_rows_in_join.set(setting.value);
else if (setting.name == "max_bytes_in_join") max_bytes_in_join.set(setting.value);
else if (setting.name == "join_overflow_mode") join_overflow_mode.set(setting.value);
else
throw Exception(
"Unknown setting " + setting.name + " for storage " + args.engine_name,
ErrorCodes::BAD_ARGUMENTS);
}
}
return StorageJoin::create(
args.data_path, args.table_name,
key_names, kind, strictness,
args.data_path,
args.table_name,
key_names,
join_use_nulls.value,
SizeLimits{max_rows_in_join.value, max_bytes_in_join.value, join_overflow_mode.value},
kind,
strictness,
args.columns);
});
}
/// Returns the address of `t` viewed as a pointer to raw bytes.
template <typename T>
static const char * rawData(T & t)
{
    const void * address = &t;
    return static_cast<const char *>(address);
}
/// Returns the size in bytes of the argument's type; the argument itself is not inspected.
template <typename T>
static size_t rawSize(T &)
{
    constexpr size_t object_size = sizeof(T);
    return object_size;
}
/// Specialization for StringRef: returns the `data` pointer of the referenced string
/// rather than the address of the StringRef object itself.
template <>
const char * rawData(const StringRef & t)
{
return t.data;
}
/// Specialization for StringRef: returns the `size` field of the referenced string
/// rather than sizeof(StringRef).
template <>
size_t rawSize(const StringRef & t)
{
return t.size;
}
/// Stream that reads the rows stored in a Join back as blocks with the structure of `sample_block`.
/// Holds a shared lock on the Join for the whole lifetime of the stream.
/// Only the key layouts listed in APPLY_FOR_JOIN_VARIANTS_LIMITED are supported.
class JoinBlockInputStream : public IProfilingBlockInputStream
{
public:
    /// parent_         - the Join to read from;
    /// max_block_size_ - row count after which a block is returned;
    /// sample_block_   - header of the requested columns (copied).
    JoinBlockInputStream(const Join & parent_, size_t max_block_size_, const Block & sample_block_)
        : parent(parent_), lock(parent.rwlock), max_block_size(max_block_size_), sample_block(sample_block_)
    {
        columns.resize(sample_block.columns());
        column_indices.resize(sample_block.columns());
        column_with_null.resize(sample_block.columns());
        for (size_t i = 0; i < sample_block.columns(); ++i)
        {
            auto & [_, type, name] = sample_block.getByPosition(i);
            if (parent.sample_block_with_keys.has(name))
            {
                /// Key column: its values are taken from the map keys, not from stored blocks.
                key_pos = i;
                column_with_null[i] = parent.sample_block_with_keys.getByName(name).type->isNullable();
            }
            else
            {
                /// Non-key column: remember where it lives in the stored blocks and whether
                /// the stored type differs from the requested one.
                auto pos = parent.sample_block_with_columns_to_add.getPositionByName(name);
                column_indices[i] = pos;
                column_with_null[i] = !parent.sample_block_with_columns_to_add.getByPosition(pos).type->equals(*type);
            }
        }
    }

    String getName() const override { return "Join"; }

    Block getHeader() const override { return sample_block; }

protected:
    /// Returns the next block of stored rows, or an empty block when there is nothing (more) to read.
    Block readImpl() override
    {
        if (parent.blocks.empty())
            return Block();

        /// Join::joinBlock uses the "*_full" map sets for FULL and RIGHT kinds,
        /// so the same sets must be read here for those kinds.
        const bool uses_full_maps
            = parent.kind == ASTTableJoin::Kind::Full || parent.kind == ASTTableJoin::Kind::Right;

        if (parent.strictness == ASTTableJoin::Strictness::Any)
        {
            if (uses_full_maps)
                return createBlock<ASTTableJoin::Strictness::Any>(parent.maps_any_full);
            return createBlock<ASTTableJoin::Strictness::Any>(parent.maps_any);
        }
        else if (parent.strictness == ASTTableJoin::Strictness::All)
        {
            if (uses_full_maps)
                return createBlock<ASTTableJoin::Strictness::All>(parent.maps_all_full);
            return createBlock<ASTTableJoin::Strictness::All>(parent.maps_all);
        }
        else
            throw Exception("Logical error: unknown JOIN strictness (must be ANY or ALL)", ErrorCodes::LOGICAL_ERROR);
    }

private:
    const Join & parent;
    std::shared_lock<std::shared_mutex> lock;
    size_t max_block_size;
    Block sample_block;

    /// For every non-key column of sample_block: its position in the stored blocks.
    ColumnNumbers column_indices;
    /// For every column of sample_block: whether it needs the wrap/unwrap handling in createBlock.
    std::vector<bool> column_with_null;
    /// Position of the key column in sample_block, if it was requested.
    std::optional<size_t> key_pos;
    MutableColumns columns;

    /// Iterator into the map being read; kept between readImpl calls.
    std::unique_ptr<void, std::function<void(void *)>> position; /// type erasure

    /// Creates fresh columns, fills them from the map variant selected by parent.type
    /// and assembles the resulting block.
    template <ASTTableJoin::Strictness STRICTNESS, typename Maps>
    Block createBlock(const Maps & maps)
    {
        for (size_t i = 0; i < sample_block.columns(); ++i)
        {
            const auto & src_col = sample_block.safeGetByPosition(i);
            columns[i] = src_col.type->createColumn();
            if (column_with_null[i])
            {
                if (key_pos == i)
                {
                    // unwrap null key column
                    ColumnNullable & nullable_col = static_cast<ColumnNullable &>(*columns[i]);
                    columns[i] = nullable_col.getNestedColumnPtr()->assumeMutable();
                }
                else
                    // wrap non key column with null
                    columns[i] = makeNullable(std::move(columns[i]))->assumeMutable();
            }
        }

        size_t rows_added = 0;

        switch (parent.type)
        {
#define M(TYPE) \
    case Join::Type::TYPE: \
        rows_added = fillColumns<STRICTNESS>(*maps.TYPE); \
        break;
            APPLY_FOR_JOIN_VARIANTS_LIMITED(M)
#undef M

            default:
                throw Exception("Unknown JOIN keys variant for limited use", ErrorCodes::UNKNOWN_SET_DATA_VARIANT);
        }

        if (!rows_added)
            return {};

        Block res = sample_block.cloneEmpty();
        for (size_t i = 0; i < columns.size(); ++i)
            if (column_with_null[i])
            {
                /// Undo the adjustment made before filling: re-wrap the key, unwrap the others.
                if (key_pos == i)
                    res.getByPosition(i).column = makeNullable(std::move(columns[i]))->assumeMutable();
                else
                {
                    const ColumnNullable & nullable_col = static_cast<const ColumnNullable &>(*columns[i]);
                    res.getByPosition(i).column = nullable_col.getNestedColumnPtr();
                }
            }
            else
                res.getByPosition(i).column = std::move(columns[i]);

        return res;
    }

    /// Appends rows from `map` to `columns`, resuming from the saved iterator.
    /// Stops once at least max_block_size rows were added or the map is exhausted.
    /// Returns the number of rows added.
    template <ASTTableJoin::Strictness STRICTNESS, typename Map>
    size_t fillColumns(const Map & map)
    {
        size_t rows_added = 0;

        /// First call: create the iterator with a deleter that knows its real type.
        if (!position)
            position = decltype(position)(
                static_cast<void *>(new typename Map::const_iterator(map.begin())),
                [](void * ptr) { delete reinterpret_cast<typename Map::const_iterator *>(ptr); });

        auto & it = *reinterpret_cast<typename Map::const_iterator *>(position.get());
        auto end = map.end();

        for (; it != end; ++it)
        {
            if constexpr (STRICTNESS == ASTTableJoin::Strictness::Any)
            {
                /// ANY: one stored row per key.
                for (size_t j = 0; j < columns.size(); ++j)
                    if (j == key_pos)
                        columns[j]->insertData(rawData(it->first), rawSize(it->first));
                    else
                        columns[j]->insertFrom(*it->second.block->getByPosition(column_indices[j]).column.get(), it->second.row_num);
                ++rows_added;
            }
            else
            {
                /// ALL: walk the linked list of stored rows for this key.
                for (auto current = &static_cast<const typename Map::mapped_type::Base_t &>(it->second); current != nullptr;
                     current = current->next)
                {
                    for (size_t j = 0; j < columns.size(); ++j)
                        if (j == key_pos)
                            columns[j]->insertData(rawData(it->first), rawSize(it->first));
                        else
                            columns[j]->insertFrom(*current->block->getByPosition(column_indices[j]).column.get(), current->row_num);
                    ++rows_added;
                }
            }

            if (rows_added >= max_block_size)
            {
                /// Step past the key just emitted so the next call resumes after it.
                ++it;
                break;
            }
        }

        return rows_added;
    }
};
// TODO: multiple stream read and index read
/// Reads the requested columns of the stored Join data through a single JoinBlockInputStream.
/// Only `column_names` and `max_block_size` are used; the remaining parameters are ignored.
BlockInputStreams StorageJoin::read(
    const Names & column_names,
    const SelectQueryInfo & /*query_info*/,
    const Context & /*context*/,
    QueryProcessingStage::Enum /*processed_stage*/,
    size_t max_block_size,
    unsigned /*num_streams*/)
{
    check(column_names);

    Block header = getSampleBlockForColumns(column_names);

    BlockInputStreams streams;
    streams.emplace_back(std::make_shared<JoinBlockInputStream>(*join, max_block_size, header));
    return streams;
}
}
......@@ -33,8 +33,19 @@ public:
/// Verify that the data structure is suitable for implementing this type of JOIN.
void assertCompatible(ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_) const;
BlockInputStreams read(
const Names & column_names,
const SelectQueryInfo & query_info,
const Context & context,
QueryProcessingStage::Enum processed_stage,
size_t max_block_size,
unsigned num_streams) override;
private:
Block sample_block;
/// Stored by value: the constructor argument may be a temporary or a local of the caller
/// (the storage factory passes a local vector), and truncate() reads this member later.
const Names key_names;
bool use_nulls;
SizeLimits limits;
ASTTableJoin::Kind kind; /// LEFT | INNER ...
ASTTableJoin::Strictness strictness; /// ANY | ALL
......@@ -48,6 +59,8 @@ protected:
const String & path_,
const String & name_,
const Names & key_names_,
bool use_nulls_,
SizeLimits limits_,
ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_,
const ColumnsDescription & columns_);
};
......
--------read--------
def [1,2] 2
abc [0] 1
def [1,2] 2
abc [0] 1
def [1,2] 2
abc [0] 1
def [1,2] 2
abc [0] 1
--------joinGet--------
abc
def
\N
abc
def
[0] 1
DROP TABLE IF EXISTS test.join_any_inner;
DROP TABLE IF EXISTS test.join_any_left;
DROP TABLE IF EXISTS test.join_any_left_null;
DROP TABLE IF EXISTS test.join_all_inner;
DROP TABLE IF EXISTS test.join_all_left;
DROP TABLE IF EXISTS test.join_string_key;
CREATE TABLE test.join_any_inner (s String, x Array(UInt8), k UInt64) ENGINE = Join(ANY, INNER, k);
CREATE TABLE test.join_any_left (s String, x Array(UInt8), k UInt64) ENGINE = Join(ANY, LEFT, k);
CREATE TABLE test.join_all_inner (s String, x Array(UInt8), k UInt64) ENGINE = Join(ALL, INNER, k);
CREATE TABLE test.join_all_left (s String, x Array(UInt8), k UInt64) ENGINE = Join(ALL, LEFT, k);
USE test;
INSERT INTO test.join_any_inner VALUES ('abc', [0], 1), ('def', [1, 2], 2);
INSERT INTO test.join_any_left VALUES ('abc', [0], 1), ('def', [1, 2], 2);
INSERT INTO test.join_all_inner VALUES ('abc', [0], 1), ('def', [1, 2], 2);
INSERT INTO test.join_all_left VALUES ('abc', [0], 1), ('def', [1, 2], 2);
-- read from StorageJoin
SELECT '--------read--------';
SELECT * from test.join_any_inner;
SELECT * from test.join_any_left;
SELECT * from test.join_all_inner;
SELECT * from test.join_all_left;
-- create StorageJoin tables with customized settings
CREATE TABLE test.join_any_left_null (s String, k UInt64) ENGINE = Join(ANY, LEFT, k) SETTINGS join_use_nulls = 1;
INSERT INTO test.join_any_left_null VALUES ('abc', 1), ('def', 2);
-- joinGet
SELECT '--------joinGet--------';
SELECT joinGet('join_any_left', 's', number) FROM numbers(3);
SELECT '';
SELECT joinGet('join_any_left_null', 's', number) FROM numbers(3);
SELECT '';
CREATE TABLE test.join_string_key (s String, x Array(UInt8), k UInt64) ENGINE = Join(ANY, LEFT, s);
INSERT INTO test.join_string_key VALUES ('abc', [0], 1), ('def', [1, 2], 2);
SELECT joinGet('join_string_key', 'x', 'abc'), joinGet('join_string_key', 'k', 'abc');
USE default;
DROP TABLE test.join_any_inner;
DROP TABLE test.join_any_left;
DROP TABLE test.join_any_left_null;
DROP TABLE test.join_all_inner;
DROP TABLE test.join_all_left;
DROP TABLE test.join_string_key;
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册