未验证 提交 2f70e895 编写于 作者: V vdimir

Update StorageJoin locking

Move joinGet into StorageJoin

Protect JoinSource with lock, add test

Add comments about locking logic
上级 f93e5b89
......@@ -25,16 +25,17 @@ ColumnPtr ExecutableFunctionJoinGet<or_null>::execute(const ColumnsWithTypeAndNa
auto key = arguments[i];
keys.emplace_back(std::move(key));
}
return join->join->joinGet(keys, result_columns).column;
return storage_join->joinGet(keys, result_columns).column;
}
template <bool or_null>
ExecutableFunctionImplPtr FunctionJoinGet<or_null>::prepare(const ColumnsWithTypeAndName &) const
{
return std::make_unique<ExecutableFunctionJoinGet<or_null>>(join, DB::Block{{return_type->createColumn(), return_type, attr_name}});
return std::make_unique<ExecutableFunctionJoinGet<or_null>>(storage_join, DB::Block{{return_type->createColumn(), return_type, attr_name}});
}
static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & context)
static std::pair<std::shared_ptr<StorageJoin>, String>
getJoin(const ColumnsWithTypeAndName & arguments, const Context & context)
{
String join_name;
if (const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get()))
......@@ -87,13 +88,12 @@ FunctionBaseImplPtr JoinGetOverloadResolver<or_null>::build(const ColumnsWithTyp
+ ", should be greater or equal to 3",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
auto [storage_join, attr_name] = getJoin(arguments, context);
auto join_holder = storage_join->getJoin();
DataTypes data_types(arguments.size() - 2);
for (size_t i = 2; i < arguments.size(); ++i)
data_types[i - 2] = arguments[i].type;
auto return_type = join_holder->join->joinGetCheckAndGetReturnType(data_types, attr_name, or_null);
auto return_type = storage_join->joinGetCheckAndGetReturnType(data_types, attr_name, or_null);
auto table_lock = storage_join->lockForShare(context.getInitialQueryId(), context.getSettingsRef().lock_acquire_timeout);
return std::make_unique<FunctionJoinGet<or_null>>(table_lock, join_holder, attr_name, data_types, return_type);
return std::make_unique<FunctionJoinGet<or_null>>(table_lock, storage_join, attr_name, data_types, return_type);
}
void registerFunctionJoinGet(FunctionFactory & factory)
......
......@@ -9,15 +9,15 @@ namespace DB
class Context;
class HashJoin;
class HashJoinHolder;
using HashJoinPtr = std::shared_ptr<HashJoin>;
class StorageJoin;
using StorageJoinPtr = std::shared_ptr<StorageJoin>;
template <bool or_null>
class ExecutableFunctionJoinGet final : public IExecutableFunctionImpl
{
public:
ExecutableFunctionJoinGet(std::shared_ptr<HashJoinHolder> join_, const DB::Block & result_columns_)
: join(std::move(join_)), result_columns(result_columns_) {}
ExecutableFunctionJoinGet(StorageJoinPtr storage_join_, const DB::Block & result_columns_)
: storage_join(std::move(storage_join_)), result_columns(result_columns_) {}
static constexpr auto name = or_null ? "joinGetOrNull" : "joinGet";
......@@ -30,7 +30,7 @@ public:
String getName() const override { return name; }
private:
std::shared_ptr<HashJoinHolder> join;
StorageJoinPtr storage_join;
DB::Block result_columns;
};
......@@ -41,10 +41,10 @@ public:
static constexpr auto name = or_null ? "joinGetOrNull" : "joinGet";
FunctionJoinGet(TableLockHolder table_lock_,
std::shared_ptr<HashJoinHolder> join_, String attr_name_,
StorageJoinPtr storage_join_, String attr_name_,
DataTypes argument_types_, DataTypePtr return_type_)
: table_lock(std::move(table_lock_))
, join(join_)
, storage_join(storage_join_)
, attr_name(std::move(attr_name_))
, argument_types(std::move(argument_types_))
, return_type(std::move(return_type_))
......@@ -60,7 +60,7 @@ public:
private:
TableLockHolder table_lock;
std::shared_ptr<HashJoinHolder> join;
StorageJoinPtr storage_join;
const String attr_name;
DataTypes argument_types;
DataTypePtr return_type;
......
......@@ -739,7 +739,7 @@ static JoinPtr tryGetStorageJoin(std::shared_ptr<TableJoin> analyzed_join)
{
if (auto * table = analyzed_join->joined_storage.get())
if (auto * storage_join = dynamic_cast<StorageJoin *>(table))
return storage_join->getJoin(analyzed_join);
return storage_join->getJoinLocked(analyzed_join);
return {};
}
......
......@@ -1203,7 +1203,6 @@ void HashJoin::joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed)
block = block.cloneWithColumns(std::move(dst_columns));
}
DataTypePtr HashJoin::joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const
{
size_t num_keys = data_types.size();
......@@ -1235,10 +1234,15 @@ DataTypePtr HashJoin::joinGetCheckAndGetReturnType(const DataTypes & data_types,
return elem.type;
}
template <typename Maps>
ColumnWithTypeAndName HashJoin::joinGetImpl(const Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const
/// TODO: return multiple columns as named tuple
/// TODO: return array of values when strictness == ASTTableJoin::Strictness::All
ColumnWithTypeAndName HashJoin::joinGet(const Block & block, const Block & block_with_columns_to_add) const
{
bool is_valid = (strictness == ASTTableJoin::Strictness::Any || strictness == ASTTableJoin::Strictness::RightAny)
&& kind == ASTTableJoin::Kind::Left;
if (!is_valid)
throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN);
/// Assemble the key block with correct names.
Block keys;
for (size_t i = 0; i < block.columns(); ++i)
......@@ -1249,25 +1253,10 @@ ColumnWithTypeAndName HashJoin::joinGetImpl(const Block & block, const Block & b
}
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(
keys, key_names_right, block_with_columns_to_add, maps_);
keys, key_names_right, block_with_columns_to_add, std::get<MapsOne>(data->maps));
return keys.getByPosition(keys.columns() - 1);
}
/// TODO: return multiple columns as named tuple
/// TODO: return array of values when strictness == ASTTableJoin::Strictness::All
ColumnWithTypeAndName HashJoin::joinGet(const Block & block, const Block & block_with_columns_to_add) const
{
if ((strictness == ASTTableJoin::Strictness::Any || strictness == ASTTableJoin::Strictness::RightAny) &&
kind == ASTTableJoin::Kind::Left)
{
return joinGetImpl(block, block_with_columns_to_add, std::get<MapsOne>(data->maps));
}
else
throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN);
}
void HashJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed)
{
const Names & key_names_left = table_join->keyNamesLeft();
......
......@@ -318,6 +318,8 @@ public:
Arena pool;
};
/// We keep correspondence between used_flags and hash table internal buffer.
/// Hash table cannot be modified during HashJoin lifetime and must be protected with lock.
void setLock(std::shared_mutex & rwlock)
{
storage_join_lock = std::shared_lock<std::shared_mutex>(rwlock);
......@@ -354,6 +356,8 @@ private:
/// Flags that indicate that particular row already used in join.
/// Flag is stored for every record in hash map.
/// Number of this flags equals to hashtable buffer size (plus one for zero value).
/// Changes in hash table broke correspondence,
/// so we must guarantee constantness of hash table during HashJoin lifetime (using method setLock)
mutable JoinStuff::JoinUsedFlags used_flags;
Sizes key_sizes;
......@@ -372,6 +376,7 @@ private:
Block totals;
/// Should be set via setLock to protect hash table from modification from StorageJoin
std::shared_lock<std::shared_mutex> storage_join_lock;
void init(Type type_);
......@@ -391,9 +396,6 @@ private:
void joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed) const;
template <typename Maps>
ColumnWithTypeAndName joinGetImpl(const Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const;
static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes);
bool empty() const;
......
......@@ -79,7 +79,7 @@ void StorageJoin::truncate(
}
HashJoinPtr StorageJoin::getJoin(std::shared_ptr<TableJoin> analyzed_join) const
HashJoinPtr StorageJoin::getJoinLocked(std::shared_ptr<TableJoin> analyzed_join) const
{
auto metadata_snapshot = getInMemoryMetadataPtr();
if (!analyzed_join->sameStrictnessAndKind(strictness, kind))
......@@ -127,6 +127,16 @@ std::optional<UInt64> StorageJoin::totalBytes(const Settings &) const
return join->getTotalByteCount();
}
DataTypePtr StorageJoin::joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const
{
return join->joinGetCheckAndGetReturnType(data_types, column_name, or_null);
}
ColumnWithTypeAndName StorageJoin::joinGet(const Block & block, const Block & block_with_columns_to_add) const
{
std::shared_lock<std::shared_mutex> lock(rwlock);
return join->joinGet(block, block_with_columns_to_add);
}
void registerStorageJoin(StorageFactory & factory)
{
......@@ -284,23 +294,24 @@ size_t rawSize(const StringRef & t)
class JoinSource : public SourceWithProgress
{
public:
JoinSource(const HashJoin & parent_, UInt64 max_block_size_, Block sample_block_)
JoinSource(HashJoinPtr join_, std::shared_mutex & rwlock, UInt64 max_block_size_, Block sample_block_)
: SourceWithProgress(sample_block_)
, parent(parent_)
, join(join_)
, lock(rwlock)
, max_block_size(max_block_size_)
, sample_block(std::move(sample_block_))
{
column_indices.resize(sample_block.columns());
auto & saved_block = parent.getJoinedData()->sample_block;
auto & saved_block = join->getJoinedData()->sample_block;
for (size_t i = 0; i < sample_block.columns(); ++i)
{
auto & [_, type, name] = sample_block.getByPosition(i);
if (parent.right_table_keys.has(name))
if (join->right_table_keys.has(name))
{
key_pos = i;
const auto & column = parent.right_table_keys.getByName(name);
const auto & column = join->right_table_keys.getByName(name);
restored_block.insert(column);
}
else
......@@ -319,18 +330,20 @@ public:
protected:
Chunk generate() override
{
if (parent.data->blocks.empty())
if (join->data->blocks.empty())
return {};
Chunk chunk;
if (!joinDispatch(parent.kind, parent.strictness, parent.data->maps,
if (!joinDispatch(join->kind, join->strictness, join->data->maps,
[&](auto kind, auto strictness, auto & map) { chunk = createChunk<kind, strictness>(map); }))
throw Exception("Logical error: unknown JOIN strictness", ErrorCodes::LOGICAL_ERROR);
return chunk;
}
private:
const HashJoin & parent;
HashJoinPtr join;
std::shared_lock<std::shared_mutex> lock;
UInt64 max_block_size;
Block sample_block;
Block restored_block; /// sample_block with parent column types
......@@ -348,7 +361,7 @@ private:
size_t rows_added = 0;
switch (parent.data->type)
switch (join->data->type)
{
#define M(TYPE) \
case HashJoin::Type::TYPE: \
......@@ -358,7 +371,7 @@ private:
#undef M
default:
throw Exception("Unsupported JOIN keys in StorageJoin. Type: " + toString(static_cast<UInt32>(parent.data->type)),
throw Exception("Unsupported JOIN keys in StorageJoin. Type: " + toString(static_cast<UInt32>(join->data->type)),
ErrorCodes::UNSUPPORTED_JOIN_KEYS);
}
......@@ -486,7 +499,8 @@ Pipe StorageJoin::read(
{
metadata_snapshot->check(column_names, getVirtuals(), getStorageID());
return Pipe(std::make_shared<JoinSource>(*join, max_block_size, metadata_snapshot->getSampleBlockForColumns(column_names, getVirtuals(), getStorageID())));
Block source_sample_block = metadata_snapshot->getSampleBlockForColumns(column_names, getVirtuals(), getStorageID());
return Pipe(std::make_shared<JoinSource>(join, rwlock, max_block_size, source_sample_block));
}
}
......@@ -14,19 +14,6 @@ class TableJoin;
class HashJoin;
using HashJoinPtr = std::shared_ptr<HashJoin>;
class HashJoinHolder
{
std::shared_lock<std::shared_mutex> lock;
public:
HashJoinPtr join;
HashJoinHolder(std::shared_mutex & rwlock, HashJoinPtr join_)
: lock(rwlock)
, join(join_)
{
}
};
/** Allows you save the state for later use on the right side of the JOIN.
* When inserted into a table, the data will be inserted into the state,
* and also written to the backup file, to restore after the restart.
......@@ -42,9 +29,12 @@ public:
void truncate(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, const Context &, TableExclusiveLockHolder &) override;
/// Access the innards.
std::shared_ptr<HashJoinHolder> getJoin() { return std::make_shared<HashJoinHolder>(rwlock, join); }
HashJoinPtr getJoin(std::shared_ptr<TableJoin> analyzed_join) const;
/// Return instance of HashJoin holding lock that protects from insertions to StorageJoin.
/// HashJoin relies on structure of hash table that's why we need to return it with locked mutex.
HashJoinPtr getJoinLocked(std::shared_ptr<TableJoin> analyzed_join) const;
DataTypePtr joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const;
ColumnWithTypeAndName joinGet(const Block & block, const Block & block_with_columns_to_add) const;
Pipe read(
const Names & column_names,
......@@ -71,8 +61,7 @@ private:
HashJoinPtr join;
/// Protect state for concurrent use in insertFromBlock and joinBlock.
/// Lock hold via HashJoin instance (or HashJoinHolder for joinGet)
/// during all query and block insertions.
/// Lock is stored in HashJoin instance during query and blocks concurrent insertions.
mutable std::shared_mutex rwlock;
void insertBlock(const Block & block) override;
......
......@@ -32,14 +32,25 @@ function read_thread_small()
done
}
function read_thread_select()
{
while true; do
echo "
SELECT * FROM storage_join_race FORMAT Null;
" | $CLICKHOUSE_CLIENT -n
done
}
# https://stackoverflow.com/questions/9954794/execute-a-shell-function-with-timeout
export -f read_thread_big;
export -f read_thread_small;
export -f read_thread_select;
TIMEOUT=20
timeout $TIMEOUT bash -c read_thread_big 2> /dev/null &
timeout $TIMEOUT bash -c read_thread_small 2> /dev/null &
timeout $TIMEOUT bash -c read_thread_select 2> /dev/null &
echo "
INSERT INTO storage_join_race SELECT number AS x, number AS y FROM numbers (10000000);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册