diff --git a/src/Functions/FunctionJoinGet.cpp b/src/Functions/FunctionJoinGet.cpp index 3a2649c11a8f26a9abcd4e1467960f8c40d11461..a2e4e2d17905c57a1fdf463d2725eeb1cd3e5da8 100644 --- a/src/Functions/FunctionJoinGet.cpp +++ b/src/Functions/FunctionJoinGet.cpp @@ -25,16 +25,17 @@ ColumnPtr ExecutableFunctionJoinGet::execute(const ColumnsWithTypeAndNa auto key = arguments[i]; keys.emplace_back(std::move(key)); } - return join->join->joinGet(keys, result_columns).column; + return storage_join->joinGet(keys, result_columns).column; } template ExecutableFunctionImplPtr FunctionJoinGet::prepare(const ColumnsWithTypeAndName &) const { - return std::make_unique>(join, DB::Block{{return_type->createColumn(), return_type, attr_name}}); + return std::make_unique>(storage_join, DB::Block{{return_type->createColumn(), return_type, attr_name}}); } -static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & context) +static std::pair, String> +getJoin(const ColumnsWithTypeAndName & arguments, const Context & context) { String join_name; if (const auto * name_col = checkAndGetColumnConst(arguments[0].column.get())) @@ -87,13 +88,12 @@ FunctionBaseImplPtr JoinGetOverloadResolver::build(const ColumnsWithTyp + ", should be greater or equal to 3", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); auto [storage_join, attr_name] = getJoin(arguments, context); - auto join_holder = storage_join->getJoin(); DataTypes data_types(arguments.size() - 2); for (size_t i = 2; i < arguments.size(); ++i) data_types[i - 2] = arguments[i].type; - auto return_type = join_holder->join->joinGetCheckAndGetReturnType(data_types, attr_name, or_null); + auto return_type = storage_join->joinGetCheckAndGetReturnType(data_types, attr_name, or_null); auto table_lock = storage_join->lockForShare(context.getInitialQueryId(), context.getSettingsRef().lock_acquire_timeout); - return std::make_unique>(table_lock, join_holder, attr_name, data_types, return_type); + return std::make_unique>(table_lock, storage_join, attr_name, data_types, return_type); } void registerFunctionJoinGet(FunctionFactory & factory) diff --git a/src/Functions/FunctionJoinGet.h b/src/Functions/FunctionJoinGet.h index 820c6cd3fa2abcebb592ff785d5a8bb1c36aaf0b..f7d3858e9022d7dbe7cdf9ae64ac9c02aa835f55 100644 --- a/src/Functions/FunctionJoinGet.h +++ b/src/Functions/FunctionJoinGet.h @@ -9,15 +9,15 @@ namespace DB class Context; class HashJoin; -class HashJoinHolder; -using HashJoinPtr = std::shared_ptr; +class StorageJoin; +using StorageJoinPtr = std::shared_ptr; template class ExecutableFunctionJoinGet final : public IExecutableFunctionImpl { public: - ExecutableFunctionJoinGet(std::shared_ptr join_, const DB::Block & result_columns_) - : join(std::move(join_)), result_columns(result_columns_) {} + ExecutableFunctionJoinGet(StorageJoinPtr storage_join_, const DB::Block & result_columns_) + : storage_join(std::move(storage_join_)), result_columns(result_columns_) {} static constexpr auto name = or_null ? "joinGetOrNull" : "joinGet"; @@ -30,7 +30,7 @@ public: String getName() const override { return name; } private: - std::shared_ptr join; + StorageJoinPtr storage_join; DB::Block result_columns; }; @@ -41,10 +41,10 @@ public: static constexpr auto name = or_null ? "joinGetOrNull" : "joinGet"; FunctionJoinGet(TableLockHolder table_lock_, - std::shared_ptr join_, String attr_name_, + StorageJoinPtr storage_join_, String attr_name_, DataTypes argument_types_, DataTypePtr return_type_) : table_lock(std::move(table_lock_)) - , join(join_) + , storage_join(storage_join_) , attr_name(std::move(attr_name_)) , argument_types(std::move(argument_types_)) , return_type(std::move(return_type_)) @@ -60,7 +60,7 @@ public: private: TableLockHolder table_lock; - std::shared_ptr join; + StorageJoinPtr storage_join; const String attr_name; DataTypes argument_types; DataTypePtr return_type; diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index 660718549b34bdc3cefbd5f08c19bc768fa98ca9..55dc622a976f74df40ce10c1a0edad57ee0bc0fb 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -739,7 +739,7 @@ static JoinPtr tryGetStorageJoin(std::shared_ptr analyzed_join) { if (auto * table = analyzed_join->joined_storage.get()) if (auto * storage_join = dynamic_cast(table)) - return storage_join->getJoin(analyzed_join); + return storage_join->getJoinLocked(analyzed_join); return {}; } diff --git a/src/Interpreters/HashJoin.cpp b/src/Interpreters/HashJoin.cpp index 9c64b9522b9357aae233aab9815bf7b5ce8a1b63..942be9d172d768a6b89f250b438317b36efe695c 100644 --- a/src/Interpreters/HashJoin.cpp +++ b/src/Interpreters/HashJoin.cpp @@ -1203,7 +1203,6 @@ void HashJoin::joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed) block = block.cloneWithColumns(std::move(dst_columns)); } - DataTypePtr HashJoin::joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const { size_t num_keys = data_types.size(); @@ -1235,10 +1234,15 @@ DataTypePtr HashJoin::joinGetCheckAndGetReturnType(const DataTypes & data_types, return elem.type; } - -template -ColumnWithTypeAndName HashJoin::joinGetImpl(const Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const +/// TODO: return multiple columns as named tuple +/// TODO: return array of values when strictness == ASTTableJoin::Strictness::All +ColumnWithTypeAndName HashJoin::joinGet(const Block & block, const Block & block_with_columns_to_add) const { + bool is_valid = (strictness == ASTTableJoin::Strictness::Any || strictness == ASTTableJoin::Strictness::RightAny) + && kind == ASTTableJoin::Kind::Left; + if (!is_valid) + throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN); + /// Assemble the key block with correct names. Block keys; for (size_t i = 0; i < block.columns(); ++i) @@ -1249,25 +1253,10 @@ ColumnWithTypeAndName HashJoin::joinGetImpl(const Block & block, const Block & b } joinBlockImpl( - keys, key_names_right, block_with_columns_to_add, maps_); + keys, key_names_right, block_with_columns_to_add, std::get(data->maps)); return keys.getByPosition(keys.columns() - 1); } - -/// TODO: return multiple columns as named tuple -/// TODO: return array of values when strictness == ASTTableJoin::Strictness::All -ColumnWithTypeAndName HashJoin::joinGet(const Block & block, const Block & block_with_columns_to_add) const -{ - if ((strictness == ASTTableJoin::Strictness::Any || strictness == ASTTableJoin::Strictness::RightAny) && - kind == ASTTableJoin::Kind::Left) - { - return joinGetImpl(block, block_with_columns_to_add, std::get(data->maps)); - } - else - throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN); -} - - void HashJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed) { const Names & key_names_left = table_join->keyNamesLeft(); diff --git a/src/Interpreters/HashJoin.h b/src/Interpreters/HashJoin.h index d212e16b1758bdc97022cd2f89d5c398bd3aefa1..075634b348da3c015a86f203ec520c775203e1f8 100644 --- a/src/Interpreters/HashJoin.h +++ b/src/Interpreters/HashJoin.h @@ -318,6 +318,8 @@ public: Arena pool; }; + /// We keep correspondence between used_flags and hash table internal buffer. + /// Hash table cannot be modified during HashJoin lifetime and must be protected with lock. void setLock(std::shared_mutex & rwlock) { storage_join_lock = std::shared_lock(rwlock); @@ -354,6 +356,8 @@ private: /// Flags that indicate that particular row already used in join. /// Flag is stored for every record in hash map. /// Number of this flags equals to hashtable buffer size (plus one for zero value). + /// Changes in hash table broke correspondence, + /// so we must guarantee constantness of hash table during HashJoin lifetime (using method setLock) mutable JoinStuff::JoinUsedFlags used_flags; Sizes key_sizes; @@ -372,6 +376,7 @@ private: Block totals; + /// Should be set via setLock to protect hash table from modification from StorageJoin std::shared_lock storage_join_lock; void init(Type type_); @@ -391,9 +396,6 @@ private: void joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed) const; - template - ColumnWithTypeAndName joinGetImpl(const Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const; - static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes); bool empty() const; diff --git a/src/Storages/StorageJoin.cpp b/src/Storages/StorageJoin.cpp index bfe866ed223b9c849f3c75c69c8d09e1791aab24..a449cebba5146e4f7ee0022fdd88dc5dd6c66c92 100644 --- a/src/Storages/StorageJoin.cpp +++ b/src/Storages/StorageJoin.cpp @@ -79,7 +79,7 @@ void StorageJoin::truncate( } -HashJoinPtr StorageJoin::getJoin(std::shared_ptr analyzed_join) const +HashJoinPtr StorageJoin::getJoinLocked(std::shared_ptr analyzed_join) const { auto metadata_snapshot = getInMemoryMetadataPtr(); if (!analyzed_join->sameStrictnessAndKind(strictness, kind)) @@ -127,6 +127,16 @@ std::optional StorageJoin::totalBytes(const Settings &) const return join->getTotalByteCount(); } +DataTypePtr StorageJoin::joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const +{ + return join->joinGetCheckAndGetReturnType(data_types, column_name, or_null); +} + +ColumnWithTypeAndName StorageJoin::joinGet(const Block & block, const Block & block_with_columns_to_add) const +{ + std::shared_lock lock(rwlock); + return join->joinGet(block, block_with_columns_to_add); +} void registerStorageJoin(StorageFactory & factory) { @@ -284,23 +294,24 @@ size_t rawSize(const StringRef & t) class JoinSource : public SourceWithProgress { public: - JoinSource(const HashJoin & parent_, UInt64 max_block_size_, Block sample_block_) + JoinSource(HashJoinPtr join_, std::shared_mutex & rwlock, UInt64 max_block_size_, Block sample_block_) : SourceWithProgress(sample_block_) - , parent(parent_) + , join(join_) + , lock(rwlock) , max_block_size(max_block_size_) , sample_block(std::move(sample_block_)) { column_indices.resize(sample_block.columns()); - auto & saved_block = parent.getJoinedData()->sample_block; + auto & saved_block = join->getJoinedData()->sample_block; for (size_t i = 0; i < sample_block.columns(); ++i) { auto & [_, type, name] = sample_block.getByPosition(i); - if (parent.right_table_keys.has(name)) + if (join->right_table_keys.has(name)) { key_pos = i; - const auto & column = parent.right_table_keys.getByName(name); + const auto & column = join->right_table_keys.getByName(name); restored_block.insert(column); } else @@ -319,18 +330,20 @@ public: protected: Chunk generate() override { - if (parent.data->blocks.empty()) + if (join->data->blocks.empty()) return {}; Chunk chunk; - if (!joinDispatch(parent.kind, parent.strictness, parent.data->maps, + if (!joinDispatch(join->kind, join->strictness, join->data->maps, [&](auto kind, auto strictness, auto & map) { chunk = createChunk(map); })) throw Exception("Logical error: unknown JOIN strictness", ErrorCodes::LOGICAL_ERROR); return chunk; } private: - const HashJoin & parent; + HashJoinPtr join; + std::shared_lock lock; + UInt64 max_block_size; Block sample_block; Block restored_block; /// sample_block with parent column types @@ -348,7 +361,7 @@ private: size_t rows_added = 0; - switch (parent.data->type) + switch (join->data->type) { #define M(TYPE) \ case HashJoin::Type::TYPE: \ @@ -358,7 +371,7 @@ private: #undef M default: - throw Exception("Unsupported JOIN keys in StorageJoin. Type: " + toString(static_cast(parent.data->type)), + throw Exception("Unsupported JOIN keys in StorageJoin. Type: " + toString(static_cast(join->data->type)), ErrorCodes::UNSUPPORTED_JOIN_KEYS); } @@ -486,7 +499,8 @@ Pipe StorageJoin::read( { metadata_snapshot->check(column_names, getVirtuals(), getStorageID()); - return Pipe(std::make_shared(*join, max_block_size, metadata_snapshot->getSampleBlockForColumns(column_names, getVirtuals(), getStorageID()))); + Block source_sample_block = metadata_snapshot->getSampleBlockForColumns(column_names, getVirtuals(), getStorageID()); + return Pipe(std::make_shared(join, rwlock, max_block_size, source_sample_block)); } } diff --git a/src/Storages/StorageJoin.h b/src/Storages/StorageJoin.h index 7e4dea5d22317141a281e7c1cf2e642d17dfe4df..49f2be825b2827dcac41f85597d9c403ddf06676 100644 --- a/src/Storages/StorageJoin.h +++ b/src/Storages/StorageJoin.h @@ -14,19 +14,6 @@ class TableJoin; class HashJoin; using HashJoinPtr = std::shared_ptr; -class HashJoinHolder -{ - std::shared_lock lock; -public: - HashJoinPtr join; - - HashJoinHolder(std::shared_mutex & rwlock, HashJoinPtr join_) - : lock(rwlock) - , join(join_) - { - } -}; - /** Allows you save the state for later use on the right side of the JOIN. * When inserted into a table, the data will be inserted into the state, * and also written to the backup file, to restore after the restart. @@ -42,9 +29,12 @@ public: void truncate(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, const Context &, TableExclusiveLockHolder &) override; - /// Access the innards. - std::shared_ptr getJoin() { return std::make_shared(rwlock, join); } - HashJoinPtr getJoin(std::shared_ptr analyzed_join) const; + /// Return instance of HashJoin holding lock that protects from insertions to StorageJoin. + /// HashJoin relies on structure of hash table that's why we need to return it with locked mutex. + HashJoinPtr getJoinLocked(std::shared_ptr analyzed_join) const; + + DataTypePtr joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const; + ColumnWithTypeAndName joinGet(const Block & block, const Block & block_with_columns_to_add) const; Pipe read( const Names & column_names, @@ -71,8 +61,7 @@ private: HashJoinPtr join; /// Protect state for concurrent use in insertFromBlock and joinBlock. - /// Lock hold via HashJoin instance (or HashJoinHolder for joinGet) - /// during all query and block insertions. + /// Lock is stored in HashJoin instance during query and blocks concurrent insertions. mutable std::shared_mutex rwlock; void insertBlock(const Block & block) override; diff --git a/tests/queries/0_stateless/01732_race_condition_storage_long.reference b/tests/queries/0_stateless/01732_race_condition_storage_join_long.reference similarity index 100% rename from tests/queries/0_stateless/01732_race_condition_storage_long.reference rename to tests/queries/0_stateless/01732_race_condition_storage_join_long.reference diff --git a/tests/queries/0_stateless/01732_race_condition_storage_long.sh b/tests/queries/0_stateless/01732_race_condition_storage_join_long.sh similarity index 83% rename from tests/queries/0_stateless/01732_race_condition_storage_long.sh rename to tests/queries/0_stateless/01732_race_condition_storage_join_long.sh index 0ce6f2b0c7a0b76822f8a498d49070dcb020ba99..b7dd76760d4d83b37998882b4aabebd4e1471a70 100755 --- a/tests/queries/0_stateless/01732_race_condition_storage_long.sh +++ b/tests/queries/0_stateless/01732_race_condition_storage_join_long.sh @@ -32,14 +32,25 @@ function read_thread_small() done } +function read_thread_select() +{ + while true; do + echo " + SELECT * FROM storage_join_race FORMAT Null; + " | $CLICKHOUSE_CLIENT -n + done +} + # https://stackoverflow.com/questions/9954794/execute-a-shell-function-with-timeout export -f read_thread_big; export -f read_thread_small; +export -f read_thread_select; TIMEOUT=20 timeout $TIMEOUT bash -c read_thread_big 2> /dev/null & timeout $TIMEOUT bash -c read_thread_small 2> /dev/null & +timeout $TIMEOUT bash -c read_thread_select 2> /dev/null & echo " INSERT INTO storage_join_race SELECT number AS x, number AS y FROM numbers (10000000);