提交 65b654a1 编写于 作者: A Amos Bird

Versatile StorageJoin

This commit does the following:

1. StorageJoin with simple keys now supports reading
2. StorageJoin can be created with Join settings applied. Syntax is
similar to MergeTree and Kafka
3. Left Any StorageJoin with one simple key can be used as a
dictionary-like structure by function joinGet.

Examples are listed in the related test file.
上级 1cc69100
#include <Functions/FunctionJoinGet.h>
#include <Functions/FunctionHelpers.h>
#include <Interpreters/Context.h>
#include <Interpreters/Join.h>
#include <Storages/StorageJoin.h>
#include <Functions/FunctionFactory.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
FunctionBasePtr FunctionBuilderJoinGet::buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const
{
if (arguments.size() != 3)
throw Exception{"Function " + getName() + " takes 3 arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
String join_name;
if (auto name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get()))
{
join_name = name_col->getValue<String>();
}
else
throw Exception{"Illegal type " + arguments[0].type->getName() + " of first argument of function " + getName()
+ ", expected a const string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
auto table = context.getTable("", join_name);
StorageJoin * storage_join = dynamic_cast<StorageJoin *>(table.get());
if (!storage_join)
throw Exception{"Table " + join_name + " should have engine StorageJoin", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
auto join = storage_join->getJoin();
String attr_name;
if (auto name_col = checkAndGetColumnConst<ColumnString>(arguments[1].column.get()))
{
attr_name = name_col->getValue<String>();
}
else
throw Exception{"Illegal type " + arguments[1].type->getName() + " of second argument of function " + getName()
+ ", expected a const string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
DataTypes data_types(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i)
data_types[i] = arguments[i].type;
return std::make_shared<DefaultFunction>(
std::make_shared<FunctionJoinGet>(join, attr_name), data_types, join->joinGetReturnType(attr_name));
}
void FunctionJoinGet::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/)
{
auto & ctn = block.getByPosition(arguments[2]);
ctn.name = ""; // make sure the key name never collide with the join columns
Block key_block = {ctn};
join->joinGet(key_block, attr_name);
block.getByPosition(result) = key_block.getByPosition(1);
}
void registerFunctionJoinGet(FunctionFactory & factory)
{
factory.registerFunction<FunctionBuilderJoinGet>();
}
}
#include <Functions/IFunction.h>
namespace DB
{
class Context;
class Join;
using JoinPtr = std::shared_ptr<Join>;
class FunctionJoinGet final : public IFunction, public std::enable_shared_from_this<FunctionJoinGet>
{
public:
static constexpr auto name = "joinGet";
FunctionJoinGet(JoinPtr join, const String & attr_name) : join(std::move(join)), attr_name(attr_name) {}
String getName() const override { return name; }
protected:
DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override { return nullptr; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
private:
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
private:
JoinPtr join;
const String attr_name;
};
class FunctionBuilderJoinGet final : public FunctionBuilderImpl
{
public:
static constexpr auto name = "joinGet";
static FunctionBuilderPtr create(const Context & context) { return std::make_shared<FunctionBuilderJoinGet>(context); }
FunctionBuilderJoinGet(const Context & context) : context(context) {}
String getName() const override { return name; }
protected:
FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const override;
DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override { return nullptr; }
private:
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
private:
const Context & context;
};
}
......@@ -40,6 +40,7 @@ void registerFunctionToLowCardinality(FunctionFactory &);
void registerFunctionLowCardinalityIndices(FunctionFactory &);
void registerFunctionLowCardinalityKeys(FunctionFactory &);
void registerFunctionsIn(FunctionFactory &);
void registerFunctionJoinGet(FunctionFactory &);
void registerFunctionsMiscellaneous(FunctionFactory & factory)
{
......@@ -80,6 +81,7 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory)
registerFunctionLowCardinalityIndices(factory);
registerFunctionLowCardinalityKeys(factory);
registerFunctionsIn(factory);
registerFunctionJoinGet(factory);
}
}
......@@ -150,15 +150,18 @@ ExpressionAction ExpressionAction::arrayJoin(const NameSet & array_joined_column
return a;
}
ExpressionAction ExpressionAction::ordinaryJoin(std::shared_ptr<const Join> join_,
const Names & join_key_names_left,
const NamesAndTypesList & columns_added_by_join_)
ExpressionAction ExpressionAction::ordinaryJoin(
std::shared_ptr<const Join> join_,
const Names & join_key_names_left,
const NamesAndTypesList & columns_added_by_join_,
const NameSet & columns_added_by_join_from_right_keys_)
{
ExpressionAction a;
a.type = JOIN;
a.join = std::move(join_);
a.join_key_names_left = join_key_names_left;
a.columns_added_by_join = columns_added_by_join_;
a.columns_added_by_join_from_right_keys = columns_added_by_join_from_right_keys_;
return a;
}
......@@ -427,7 +430,7 @@ void ExpressionAction::execute(Block & block) const
case JOIN:
{
join->joinBlock(block);
join->joinBlock(block, join_key_names_left, columns_added_by_join_from_right_keys);
break;
}
......@@ -1085,7 +1088,7 @@ BlockInputStreamPtr ExpressionActions::createStreamWithNonJoinedDataIfFullOrRigh
{
for (const auto & action : actions)
if (action.join && (action.join->getKind() == ASTTableJoin::Kind::Full || action.join->getKind() == ASTTableJoin::Kind::Right))
return action.join->createStreamWithNonJoinedRows(source_header, max_block_size);
return action.join->createStreamWithNonJoinedRows(source_header, action.join_key_names_left, max_block_size);
return {};
}
......
......@@ -102,6 +102,7 @@ public:
std::shared_ptr<const Join> join;
Names join_key_names_left;
NamesAndTypesList columns_added_by_join;
NameSet columns_added_by_join_from_right_keys;
/// For PROJECT.
NamesWithAliases projection;
......@@ -118,7 +119,7 @@ public:
static ExpressionAction addAliases(const NamesWithAliases & aliased_columns_);
static ExpressionAction arrayJoin(const NameSet & array_joined_columns, bool array_join_is_left, const Context & context);
static ExpressionAction ordinaryJoin(std::shared_ptr<const Join> join_, const Names & join_key_names_left,
const NamesAndTypesList & columns_added_by_join_);
const NamesAndTypesList & columns_added_by_join_, const NameSet & columns_added_by_join_from_right_keys_);
/// Which columns necessary to perform this action.
Names getNeededColumns() const;
......
......@@ -556,12 +556,13 @@ void ExpressionAnalyzer::addJoinAction(ExpressionActionsPtr & actions, bool only
columns_added_by_join_list.push_back(joined_column.name_and_type);
if (only_types)
actions->add(ExpressionAction::ordinaryJoin(nullptr, analyzedJoin().key_names_left, columns_added_by_join_list));
actions->add(ExpressionAction::ordinaryJoin(nullptr, analyzedJoin().key_names_left,
columns_added_by_join_list, columns_added_by_join_from_right_keys));
else
for (auto & subquery_for_set : subqueries_for_sets)
if (subquery_for_set.second.join)
actions->add(ExpressionAction::ordinaryJoin(subquery_for_set.second.join, analyzedJoin().key_names_left,
columns_added_by_join_list));
columns_added_by_join_list, columns_added_by_join_from_right_keys));
}
bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_types)
......@@ -621,10 +622,8 @@ bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_ty
if (!subquery_for_set.join)
{
JoinPtr join = std::make_shared<Join>(
analyzedJoin().key_names_left, analyzedJoin().key_names_right, columns_added_by_join_from_right_keys,
settings.join_use_nulls, settings.size_limits_for_join,
join_params.kind, join_params.strictness);
JoinPtr join = std::make_shared<Join>(analyzedJoin().key_names_right, settings.join_use_nulls,
settings.size_limits_for_join, join_params.kind, join_params.strictness);
/** For GLOBAL JOINs (in the case, for example, of the push method for executing GLOBAL subqueries), the following occurs
* - in the addExternalStorage function, the JOIN (SELECT ...) subquery is replaced with JOIN _data1,
......
......@@ -30,12 +30,10 @@ namespace ErrorCodes
}
Join::Join(const Names & key_names_left_, const Names & key_names_right_, const NameSet & needed_key_names_right_,
bool use_nulls_, const SizeLimits & limits, ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_)
Join::Join(const Names & key_names_right_, bool use_nulls_, const SizeLimits & limits,
ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_)
: kind(kind_), strictness(strictness_),
key_names_left(key_names_left_),
key_names_right(key_names_right_),
needed_key_names_right(needed_key_names_right_),
use_nulls(use_nulls_),
log(&Logger::get("Join")),
limits(limits)
......@@ -662,7 +660,12 @@ namespace
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename Maps>
void Join::joinBlockImpl(Block & block, const Maps & maps) const
void Join::joinBlockImpl(
Block & block,
const Names & key_names_left,
const NameSet & needed_key_names_right,
const Block & block_with_columns_to_add,
const Maps & maps) const
{
size_t keys_size = key_names_left.size();
ColumnRawPtrs key_columns(keys_size);
......@@ -734,8 +737,8 @@ void Join::joinBlockImpl(Block & block, const Maps & maps) const
{
const ColumnWithTypeAndName & src_column = sample_block_with_columns_to_add.safeGetByPosition(i);
/// Don't insert column if it's in left block.
if (!block.has(src_column.name))
/// Don't insert column if it's in left block or not explicitly required.
if (!block.has(src_column.name) && block_with_columns_to_add.has(src_column.name))
{
added_columns.push_back(src_column.column->cloneEmpty());
added_columns.back()->reserve(src_column.column->size());
......@@ -746,7 +749,6 @@ void Join::joinBlockImpl(Block & block, const Maps & maps) const
size_t rows = block.rows();
/// Used with ANY INNER JOIN
std::unique_ptr<IColumn::Filter> filter;
bool filter_left_keys = (kind == ASTTableJoin::Kind::Inner || kind == ASTTableJoin::Kind::Right) && strictness == ASTTableJoin::Strictness::Any;
......@@ -875,7 +877,7 @@ void Join::joinBlockImplCross(Block & block) const
}
void Join::checkTypesOfKeys(const Block & block_left, const Block & block_right) const
void Join::checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right) const
{
size_t keys_size = key_names_left.size();
......@@ -895,30 +897,90 @@ void Join::checkTypesOfKeys(const Block & block_left, const Block & block_right)
}
void Join::joinBlock(Block & block) const
static void checkTypeOfKey(const Block & block_left, const Block & block_right)
{
auto & [c1, left_type_origin, left_name] = block_left.safeGetByPosition(0);
auto & [c2, right_type_origin, right_name] = block_right.safeGetByPosition(0);
auto left_type = removeNullable(left_type_origin);
auto right_type = removeNullable(right_type_origin);
if (!left_type->equals(*right_type))
throw Exception("Type mismatch of columns to joinGet by: "
+ left_name + " " + left_type->getName() + " at left, "
+ right_name + " " + right_type->getName() + " at right",
ErrorCodes::TYPE_MISMATCH);
}
DataTypePtr Join::joinGetReturnType(const String & column_name) const
{
std::shared_lock lock(rwlock);
if (!sample_block_with_columns_to_add.has(column_name))
throw Exception("StorageJoin doesn't contain column " + column_name, ErrorCodes::LOGICAL_ERROR);
return sample_block_with_columns_to_add.getByName(column_name).type;
}
template <typename Maps>
void Join::joinGetImpl(Block & block, const String & column_name, const Maps & maps) const
{
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(
block, {block.getByPosition(0).name}, {}, {sample_block_with_columns_to_add.getByName(column_name)}, maps);
}
// TODO: support composite key
// TODO: return multible columns as named tuple
// TODO: return array of values when strictness == ASTTableJoin::Strictness::All
void Join::joinGet(Block & block, const String & column_name) const
{
std::shared_lock lock(rwlock);
if (key_names_right.size() != 1)
throw Exception("joinGet only supports StorageJoin containing exactly one key", ErrorCodes::LOGICAL_ERROR);
checkTypeOfKey(block, sample_block_with_keys);
if (kind == ASTTableJoin::Kind::Left && strictness == ASTTableJoin::Strictness::Any)
joinGetImpl(block, column_name, maps_any);
else
throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::LOGICAL_ERROR);
}
void Join::joinBlock(Block & block, const Names & key_names_left, const NameSet & needed_key_names_right) const
{
// std::cerr << "joinBlock: " << block.dumpStructure() << "\n";
std::shared_lock lock(rwlock);
checkTypesOfKeys(block, sample_block_with_keys);
checkTypesOfKeys(block, key_names_left, sample_block_with_keys);
if (kind == ASTTableJoin::Kind::Left && strictness == ASTTableJoin::Strictness::Any)
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(block, maps_any);
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_any);
else if (kind == ASTTableJoin::Kind::Inner && strictness == ASTTableJoin::Strictness::Any)
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::Any>(block, maps_any);
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::Any>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_any);
else if (kind == ASTTableJoin::Kind::Left && strictness == ASTTableJoin::Strictness::All)
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::All>(block, maps_all);
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::All>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_all);
else if (kind == ASTTableJoin::Kind::Inner && strictness == ASTTableJoin::Strictness::All)
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::All>(block, maps_all);
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::All>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_all);
else if (kind == ASTTableJoin::Kind::Full && strictness == ASTTableJoin::Strictness::Any)
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(block, maps_any_full);
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_any_full);
else if (kind == ASTTableJoin::Kind::Right && strictness == ASTTableJoin::Strictness::Any)
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::Any>(block, maps_any_full);
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::Any>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_any_full);
else if (kind == ASTTableJoin::Kind::Full && strictness == ASTTableJoin::Strictness::All)
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::All>(block, maps_all_full);
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::All>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_all_full);
else if (kind == ASTTableJoin::Kind::Right && strictness == ASTTableJoin::Strictness::All)
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::All>(block, maps_all_full);
joinBlockImpl<ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::All>(
block, key_names_left, needed_key_names_right, sample_block_with_columns_to_add, maps_all_full);
else if (kind == ASTTableJoin::Kind::Cross)
joinBlockImplCross(block);
else
......@@ -995,14 +1057,14 @@ struct AdderNonJoined<ASTTableJoin::Strictness::All, Mapped>
class NonJoinedBlockInputStream : public IProfilingBlockInputStream
{
public:
NonJoinedBlockInputStream(const Join & parent_, const Block & left_sample_block, size_t max_block_size_)
NonJoinedBlockInputStream(const Join & parent_, const Block & left_sample_block, const Names & key_names_left, size_t max_block_size_)
: parent(parent_), max_block_size(max_block_size_)
{
/** left_sample_block contains keys and "left" columns.
* result_sample_block - keys, "left" columns, and "right" columns.
*/
size_t num_keys = parent.key_names_left.size();
size_t num_keys = key_names_left.size();
size_t num_columns_left = left_sample_block.columns() - num_keys;
size_t num_columns_right = parent.sample_block_with_columns_to_add.columns();
......@@ -1019,7 +1081,7 @@ public:
column_indices_keys_and_right.reserve(num_keys + num_columns_right);
std::vector<bool> is_key_column_in_left_block(num_keys + num_columns_left, false);
for (const std::string & key : parent.key_names_left)
for (const std::string & key : key_names_left)
{
size_t key_pos = left_sample_block.getPositionByName(key);
is_key_column_in_left_block[key_pos] = true;
......@@ -1170,9 +1232,9 @@ private:
};
BlockInputStreamPtr Join::createStreamWithNonJoinedRows(const Block & left_sample_block, size_t max_block_size) const
BlockInputStreamPtr Join::createStreamWithNonJoinedRows(const Block & left_sample_block, const Names & key_names_left, size_t max_block_size) const
{
return std::make_shared<NonJoinedBlockInputStream>(*this, left_sample_block, max_block_size);
return std::make_shared<NonJoinedBlockInputStream>(*this, left_sample_block, key_names_left, max_block_size);
}
......
......@@ -219,8 +219,8 @@ struct JoinKeyGetterHashed
class Join
{
public:
Join(const Names & key_names_left_, const Names & key_names_right_, const NameSet & needed_key_names_right_,
bool use_nulls_, const SizeLimits & limits, ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_);
Join(const Names & key_names_right_, bool use_nulls_, const SizeLimits & limits,
ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_);
bool empty() { return type == Type::EMPTY; }
......@@ -237,7 +237,13 @@ public:
/** Join data from the map (that was previously built by calls to insertFromBlock) to the block with data from "left" table.
* Could be called from different threads in parallel.
*/
void joinBlock(Block & block) const;
void joinBlock(Block & block, const Names & key_names_left, const NameSet & needed_key_names_right) const;
/// Infer the return type for joinGet function
DataTypePtr joinGetReturnType(const String & column_name) const;
/// Used by joinGet function that turns StorageJoin into a dictionary
void joinGet(Block & block, const String & column_name) const;
/** Keep "totals" (separate part of dataset, see WITH TOTALS) to use later.
*/
......@@ -251,7 +257,7 @@ public:
* Use only after all calls to joinBlock was done.
* left_sample_block is passed without account of 'use_nulls' setting (columns will be converted to Nullable inside).
*/
BlockInputStreamPtr createStreamWithNonJoinedRows(const Block & left_sample_block, size_t max_block_size) const;
BlockInputStreamPtr createStreamWithNonJoinedRows(const Block & left_sample_block, const Names & key_names_left, size_t max_block_size) const;
/// Number of keys in all built JOIN maps.
size_t getTotalRowCount() const;
......@@ -320,6 +326,16 @@ public:
M(keys256) \
M(hashed)
/// Used for reading from StorageJoin and applying joinGet function
#define APPLY_FOR_JOIN_VARIANTS_LIMITED(M) \
M(key8) \
M(key16) \
M(key32) \
M(key64) \
M(key_string) \
M(key_fixed_string)
enum class Type
{
EMPTY,
......@@ -353,16 +369,13 @@ public:
private:
friend class NonJoinedBlockInputStream;
friend class JoinBlockInputStream;
ASTTableJoin::Kind kind;
ASTTableJoin::Strictness strictness;
/// Names of key columns (columns for equi-JOIN) in "left" table (in the order they appear in USING clause).
const Names key_names_left;
/// Names of key columns (columns for equi-JOIN) in "right" table (in the order they appear in USING clause).
const Names key_names_right;
/// Names of key columns in the "right" table which should stay in block after join.
const NameSet needed_key_names_right;
/// Substitute NULLs for non-JOINed rows.
bool use_nulls;
......@@ -408,12 +421,20 @@ private:
void init(Type type_);
/// Throw an exception if blocks have different types of key columns.
void checkTypesOfKeys(const Block & block_left, const Block & block_right) const;
void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right) const;
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename Maps>
void joinBlockImpl(Block & block, const Maps & maps) const;
void joinBlockImpl(
Block & block,
const Names & key_names_left,
const NameSet & needed_key_names_right,
const Block & block_with_columns_to_add,
const Maps & maps) const;
void joinBlockImplCross(Block & block) const;
template <typename Maps>
void joinGetImpl(Block & block, const String & column_name, const Maps & maps) const;
};
using JoinPtr = std::shared_ptr<Join>;
......
......@@ -86,11 +86,11 @@ StoragePtr StorageFactory::get(
name = engine_def.name;
if (storage_def->settings && !endsWith(name, "MergeTree") && name != "Kafka")
if (storage_def->settings && !endsWith(name, "MergeTree") && name != "Kafka" && name != "Join")
{
throw Exception(
"Engine " + name + " doesn't support SETTINGS clause. "
"Currently only the MergeTree family of engines and Kafka engine supports it",
"Currently only the MergeTree family of engines, Kafka engine and Join engine support it",
ErrorCodes::BAD_ARGUMENTS);
}
......
#include <Storages/StorageJoin.h>
#include <Storages/StorageFactory.h>
#include <Interpreters/Join.h>
#include <Parsers/ASTCreateQuery.h>
#include <Parsers/ASTIdentifier.h>
#include <Common/typeid_cast.h>
#include <Core/ColumnNumbers.h>
#include <DataStreams/IProfilingBlockInputStream.h>
#include <DataTypes/NestedUtils.h>
#include <Poco/String.h> /// toLower
#include <Poco/File.h>
......@@ -13,6 +17,7 @@ namespace DB
namespace ErrorCodes
{
extern const int UNKNOWN_SET_DATA_VARIANT;
extern const int NO_SUCH_COLUMN_IN_TABLE;
extern const int INCOMPATIBLE_TYPE_OF_JOIN;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
......@@ -24,18 +29,23 @@ StorageJoin::StorageJoin(
const String & path_,
const String & name_,
const Names & key_names_,
ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_,
bool use_nulls_,
SizeLimits limits_,
ASTTableJoin::Kind kind_,
ASTTableJoin::Strictness strictness_,
const ColumnsDescription & columns_)
: StorageSetOrJoinBase{path_, name_, columns_},
key_names(key_names_), kind(kind_), strictness(strictness_)
: StorageSetOrJoinBase{path_, name_, columns_}
, key_names(key_names_)
, use_nulls(use_nulls_)
, limits(limits_)
, kind(kind_)
, strictness(strictness_)
{
for (const auto & key : key_names)
if (!getColumns().hasPhysical(key))
throw Exception{"Key column (" + key + ") does not exist in table declaration.", ErrorCodes::NO_SUCH_COLUMN_IN_TABLE};
/// NOTE StorageJoin doesn't use join_use_nulls setting.
join = std::make_shared<Join>(key_names, key_names, NameSet(), false /* use_nulls */, SizeLimits(), kind, strictness);
join = std::make_shared<Join>(key_names, use_nulls, limits, kind, strictness);
join->setSampleBlock(getSampleBlock().sortColumns());
restore();
}
......@@ -48,7 +58,7 @@ void StorageJoin::truncate(const ASTPtr &)
Poco::File(path + "tmp/").createDirectories();
increment = 0;
join = std::make_shared<Join>(key_names, key_names, NameSet(), false /* use_nulls */, SizeLimits(), kind, strictness);
join = std::make_shared<Join>(key_names, use_nulls, limits, kind, strictness);
join->setSampleBlock(getSampleBlock().sortColumns());
}
......@@ -119,11 +129,237 @@ void registerStorageJoin(StorageFactory & factory)
key_names.push_back(key->name);
}
auto & settings = args.context.getSettingsRef();
auto join_use_nulls = settings.join_use_nulls;
auto max_rows_in_join = settings.max_rows_in_join;
auto max_bytes_in_join = settings.max_bytes_in_join;
auto join_overflow_mode = settings.join_overflow_mode;
if (args.storage_def && args.storage_def->settings)
{
for (const ASTSetQuery::Change & setting : args.storage_def->settings->changes)
{
if (setting.name == "join_use_nulls") join_use_nulls.set(setting.value);
else if (setting.name == "max_rows_in_join") max_rows_in_join.set(setting.value);
else if (setting.name == "max_bytes_in_join") max_bytes_in_join.set(setting.value);
else if (setting.name == "join_overflow_mode") join_overflow_mode.set(setting.value);
else
throw Exception(
"Unknown setting " + setting.name + " for storage " + args.engine_name,
ErrorCodes::BAD_ARGUMENTS);
}
}
return StorageJoin::create(
args.data_path, args.table_name,
key_names, kind, strictness,
args.data_path,
args.table_name,
key_names,
join_use_nulls.value,
SizeLimits{max_rows_in_join.value, max_bytes_in_join.value, join_overflow_mode.value},
kind,
strictness,
args.columns);
});
}
template <typename T>
static const char * rawData(T & t)
{
return reinterpret_cast<const char *>(&t);
}
template <typename T>
static size_t rawSize(T &)
{
return sizeof(T);
}
template <>
const char * rawData(const StringRef & t)
{
return t.data;
}
template <>
size_t rawSize(const StringRef & t)
{
return t.size;
}
class JoinBlockInputStream : public IProfilingBlockInputStream
{
public:
JoinBlockInputStream(const Join & parent_, size_t max_block_size_, Block & sample_block_)
: parent(parent_), lock(parent.rwlock), max_block_size(max_block_size_), sample_block(sample_block_)
{
columns.resize(sample_block.columns());
column_indices.resize(sample_block.columns());
column_with_null.resize(sample_block.columns());
for (size_t i = 0; i < sample_block.columns(); ++i)
{
auto & [_, type, name] = sample_block.getByPosition(i);
if (parent.sample_block_with_keys.has(name))
{
key_pos = i;
column_with_null[i] = parent.sample_block_with_keys.getByName(name).type->isNullable();
}
else
{
auto pos = parent.sample_block_with_columns_to_add.getPositionByName(name);
column_indices[i] = pos;
column_with_null[i] = !parent.sample_block_with_columns_to_add.getByPosition(pos).type->equals(*type);
}
}
}
String getName() const override { return "Join"; }
Block getHeader() const override { return sample_block; }
protected:
Block readImpl() override
{
if (parent.blocks.empty())
return Block();
if (parent.strictness == ASTTableJoin::Strictness::Any)
return createBlock<ASTTableJoin::Strictness::Any>(parent.maps_any);
else if (parent.strictness == ASTTableJoin::Strictness::All)
return createBlock<ASTTableJoin::Strictness::All>(parent.maps_all);
else
throw Exception("Logical error: unknown JOIN strictness (must be ANY or ALL)", ErrorCodes::LOGICAL_ERROR);
}
private:
const Join & parent;
std::shared_lock<std::shared_mutex> lock;
size_t max_block_size;
Block sample_block;
ColumnNumbers column_indices;
std::vector<bool> column_with_null;
std::optional<size_t> key_pos;
MutableColumns columns;
std::unique_ptr<void, std::function<void(void *)>> position; /// type erasure
template <ASTTableJoin::Strictness STRICTNESS, typename Maps>
Block createBlock(const Maps & maps)
{
for (size_t i = 0; i < sample_block.columns(); ++i)
{
const auto & src_col = sample_block.safeGetByPosition(i);
columns[i] = src_col.type->createColumn();
if (column_with_null[i])
{
if (key_pos == i)
{
// unwrap null key column
ColumnNullable & nullable_col = static_cast<ColumnNullable &>(*columns[i]);
columns[i] = nullable_col.getNestedColumnPtr()->assumeMutable();
}
else
// wrap non key column with null
columns[i] = makeNullable(std::move(columns[i]))->assumeMutable();
}
}
size_t rows_added = 0;
switch (parent.type)
{
#define M(TYPE) \
case Join::Type::TYPE: \
rows_added = fillColumns<STRICTNESS>(*maps.TYPE); \
break;
APPLY_FOR_JOIN_VARIANTS_LIMITED(M)
#undef M
default:
throw Exception("Unknown JOIN keys variant for limited use", ErrorCodes::UNKNOWN_SET_DATA_VARIANT);
}
if (!rows_added)
return {};
Block res = sample_block.cloneEmpty();
for (size_t i = 0; i < columns.size(); ++i)
if (column_with_null[i])
{
if (key_pos == i)
res.getByPosition(i).column = makeNullable(std::move(columns[i]))->assumeMutable();
else
{
const ColumnNullable & nullable_col = static_cast<const ColumnNullable &>(*columns[i]);
res.getByPosition(i).column = nullable_col.getNestedColumnPtr();
}
}
else
res.getByPosition(i).column = std::move(columns[i]);
return res;
}
template <ASTTableJoin::Strictness STRICTNESS, typename Map>
size_t fillColumns(const Map & map)
{
size_t rows_added = 0;
if (!position)
position = decltype(position)(
static_cast<void *>(new typename Map::const_iterator(map.begin())),
[](void * ptr) { delete reinterpret_cast<typename Map::const_iterator *>(ptr); });
auto & it = *reinterpret_cast<typename Map::const_iterator *>(position.get());
auto end = map.end();
for (; it != end; ++it)
{
if constexpr (STRICTNESS == ASTTableJoin::Strictness::Any)
{
for (size_t j = 0; j < columns.size(); ++j)
if (j == key_pos)
columns[j]->insertData(rawData(it->first), rawSize(it->first));
else
columns[j]->insertFrom(*it->second.block->getByPosition(column_indices[j]).column.get(), it->second.row_num);
++rows_added;
}
else
for (auto current = &static_cast<const typename Map::mapped_type::Base_t &>(it->second); current != nullptr;
current = current->next)
{
for (size_t j = 0; j < columns.size(); ++j)
if (j == key_pos)
columns[j]->insertData(rawData(it->first), rawSize(it->first));
else
columns[j]->insertFrom(*current->block->getByPosition(column_indices[j]).column.get(), current->row_num);
++rows_added;
}
if (rows_added >= max_block_size)
{
++it;
break;
}
}
return rows_added;
}
};
// TODO: multiple stream read and index read
BlockInputStreams StorageJoin::read(
const Names & column_names,
const SelectQueryInfo & /*query_info*/,
const Context & /*context*/,
QueryProcessingStage::Enum /*processed_stage*/,
size_t max_block_size,
unsigned /*num_streams*/)
{
check(column_names);
Block sample_block = getSampleBlockForColumns(column_names);
return {std::make_shared<JoinBlockInputStream>(*join, max_block_size, sample_block)};
}
}
......@@ -33,8 +33,19 @@ public:
/// Verify that the data structure is suitable for implementing this type of JOIN.
void assertCompatible(ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_) const;
BlockInputStreams read(
const Names & column_names,
const SelectQueryInfo & query_info,
const Context & context,
QueryProcessingStage::Enum processed_stage,
size_t max_block_size,
unsigned num_streams) override;
private:
Block sample_block;
const Names & key_names;
bool use_nulls;
SizeLimits limits;
ASTTableJoin::Kind kind; /// LEFT | INNER ...
ASTTableJoin::Strictness strictness; /// ANY | ALL
......@@ -48,6 +59,8 @@ protected:
const String & path_,
const String & name_,
const Names & key_names_,
bool use_nulls_,
SizeLimits limits_,
ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_,
const ColumnsDescription & columns_);
};
......
--------read--------
def [1,2] 2
abc [0] 1
def [1,2] 2
abc [0] 1
def [1,2] 2
abc [0] 1
def [1,2] 2
abc [0] 1
--------joinGet--------
abc
def
\N
abc
def
[0] 1
DROP TABLE IF EXISTS test.join_any_inner;
DROP TABLE IF EXISTS test.join_any_left;
DROP TABLE IF EXISTS test.join_any_left_null;
DROP TABLE IF EXISTS test.join_all_inner;
DROP TABLE IF EXISTS test.join_all_left;
DROP TABLE IF EXISTS test.join_string_key;
CREATE TABLE test.join_any_inner (s String, x Array(UInt8), k UInt64) ENGINE = Join(ANY, INNER, k);
CREATE TABLE test.join_any_left (s String, x Array(UInt8), k UInt64) ENGINE = Join(ANY, LEFT, k);
CREATE TABLE test.join_all_inner (s String, x Array(UInt8), k UInt64) ENGINE = Join(ALL, INNER, k);
CREATE TABLE test.join_all_left (s String, x Array(UInt8), k UInt64) ENGINE = Join(ALL, LEFT, k);
USE test;
INSERT INTO test.join_any_inner VALUES ('abc', [0], 1), ('def', [1, 2], 2);
INSERT INTO test.join_any_left VALUES ('abc', [0], 1), ('def', [1, 2], 2);
INSERT INTO test.join_all_inner VALUES ('abc', [0], 1), ('def', [1, 2], 2);
INSERT INTO test.join_all_left VALUES ('abc', [0], 1), ('def', [1, 2], 2);
-- read from StorageJoin
SELECT '--------read--------';
SELECT * from test.join_any_inner;
SELECT * from test.join_any_left;
SELECT * from test.join_all_inner;
SELECT * from test.join_all_left;
-- create StorageJoin tables with customized settings
CREATE TABLE test.join_any_left_null (s String, k UInt64) ENGINE = Join(ANY, LEFT, k) SETTINGS join_use_nulls = 1;
INSERT INTO test.join_any_left_null VALUES ('abc', 1), ('def', 2);
-- joinGet
SELECT '--------joinGet--------';
SELECT joinGet('join_any_left', 's', number) FROM numbers(3);
SELECT '';
SELECT joinGet('join_any_left_null', 's', number) FROM numbers(3);
SELECT '';
CREATE TABLE test.join_string_key (s String, x Array(UInt8), k UInt64) ENGINE = Join(ANY, LEFT, s);
INSERT INTO test.join_string_key VALUES ('abc', [0], 1), ('def', [1, 2], 2);
SELECT joinGet('join_string_key', 'x', 'abc'), joinGet('join_string_key', 'k', 'abc');
USE default;
DROP TABLE test.join_any_inner;
DROP TABLE test.join_any_left;
DROP TABLE test.join_any_left_null;
DROP TABLE test.join_all_inner;
DROP TABLE test.join_all_left;
DROP TABLE test.join_string_key;
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册