提交 9cfdb70f 编写于 作者: A Alexey Milovidov

Merge branch 'amosbird-jgon'

......@@ -60,7 +60,8 @@ static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & co
return std::make_pair(storage_join, attr_name);
}
FunctionBaseImplPtr JoinGetOverloadResolver::build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const
template <bool or_null>
FunctionBaseImplPtr JoinGetOverloadResolver<or_null>::build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const
{
auto [storage_join, attr_name] = getJoin(arguments, context);
auto join = storage_join->getJoin();
......@@ -71,40 +72,52 @@ FunctionBaseImplPtr JoinGetOverloadResolver::build(const ColumnsWithTypeAndName
for (size_t i = 0; i < arguments.size(); ++i)
data_types[i] = arguments[i].type;
auto return_type = join->joinGetReturnType(attr_name);
return std::make_unique<FunctionJoinGet>(table_lock, storage_join, join, attr_name, data_types, return_type);
auto return_type = join->joinGetReturnType(attr_name, or_null);
return std::make_unique<FunctionJoinGet<or_null>>(table_lock, storage_join, join, attr_name, data_types, return_type);
}
DataTypePtr JoinGetOverloadResolver::getReturnType(const ColumnsWithTypeAndName & arguments) const
template <bool or_null>
DataTypePtr JoinGetOverloadResolver<or_null>::getReturnType(const ColumnsWithTypeAndName & arguments) const
{
auto [storage_join, attr_name] = getJoin(arguments, context);
auto join = storage_join->getJoin();
return join->joinGetReturnType(attr_name);
return join->joinGetReturnType(attr_name, or_null);
}
void ExecutableFunctionJoinGet::execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
template <bool or_null>
void ExecutableFunctionJoinGet<or_null>::execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
{
auto ctn = block.getByPosition(arguments[2]);
if (isColumnConst(*ctn.column))
ctn.column = ctn.column->cloneResized(1);
ctn.name = ""; // make sure the key name never collide with the join columns
Block key_block = {ctn};
join->joinGet(key_block, attr_name);
join->joinGet(key_block, attr_name, or_null);
auto & result_ctn = key_block.getByPosition(1);
if (isColumnConst(*ctn.column))
result_ctn.column = ColumnConst::create(result_ctn.column, input_rows_count);
block.getByPosition(result) = result_ctn;
}
ExecutableFunctionImplPtr FunctionJoinGet::prepare(const Block &, const ColumnNumbers &, size_t) const
template <bool or_null>
ExecutableFunctionImplPtr FunctionJoinGet<or_null>::prepare(const Block &, const ColumnNumbers &, size_t) const
{
return std::make_unique<ExecutableFunctionJoinGet>(join, attr_name);
return std::make_unique<ExecutableFunctionJoinGet<or_null>>(join, attr_name);
}
void registerFunctionJoinGet(FunctionFactory & factory)
{
factory.registerFunction<JoinGetOverloadResolver>();
// joinGet
factory.registerFunction<JoinGetOverloadResolver<false>>();
// joinGetOrNull
factory.registerFunction<JoinGetOverloadResolver<true>>();
}
template class ExecutableFunctionJoinGet<true>;
template class ExecutableFunctionJoinGet<false>;
template class FunctionJoinGet<true>;
template class FunctionJoinGet<false>;
template class JoinGetOverloadResolver<true>;
template class JoinGetOverloadResolver<false>;
}
......@@ -9,13 +9,14 @@ class Context;
class HashJoin;
using HashJoinPtr = std::shared_ptr<HashJoin>;
template <bool or_null>
class ExecutableFunctionJoinGet final : public IExecutableFunctionImpl
{
public:
ExecutableFunctionJoinGet(HashJoinPtr join_, String attr_name_)
: join(std::move(join_)), attr_name(std::move(attr_name_)) {}
static constexpr auto name = "joinGet";
static constexpr auto name = or_null ? "joinGetOrNull" : "joinGet";
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
......@@ -30,10 +31,11 @@ private:
const String attr_name;
};
template <bool or_null>
class FunctionJoinGet final : public IFunctionBaseImpl
{
public:
static constexpr auto name = "joinGet";
static constexpr auto name = or_null ? "joinGetOrNull" : "joinGet";
FunctionJoinGet(TableStructureReadLockHolder table_lock_, StoragePtr storage_join_,
HashJoinPtr join_, String attr_name_,
......@@ -63,10 +65,11 @@ private:
DataTypePtr return_type;
};
template <bool or_null>
class JoinGetOverloadResolver final : public IFunctionOverloadResolverImpl
{
public:
static constexpr auto name = "joinGet";
static constexpr auto name = or_null ? "joinGetOrNull" : "joinGet";
static FunctionOverloadResolverImplPtr create(const Context & context) { return std::make_unique<JoinGetOverloadResolver>(context); }
explicit JoinGetOverloadResolver(const Context & context_) : context(context_) {}
......
......@@ -681,12 +681,10 @@ public:
type_name.reserve(num_columns_to_add);
right_indexes.reserve(num_columns_to_add);
for (size_t i = 0; i < num_columns_to_add; ++i)
for (auto & src_column : block_with_columns_to_add)
{
const ColumnWithTypeAndName & src_column = sample_block_with_columns_to_add.safeGetByPosition(i);
/// Don't insert column if it's in left block or not explicitly required.
if (!block.has(src_column.name) && block_with_columns_to_add.has(src_column.name))
/// Don't insert column if it's in left block
if (!block.has(src_column.name))
addColumn(src_column);
}
......@@ -1158,28 +1156,31 @@ static void checkTypeOfKey(const Block & block_left, const Block & block_right)
}
DataTypePtr HashJoin::joinGetReturnType(const String & column_name) const
DataTypePtr HashJoin::joinGetReturnType(const String & column_name, bool or_null) const
{
std::shared_lock lock(data->rwlock);
if (!sample_block_with_columns_to_add.has(column_name))
throw Exception("StorageJoin doesn't contain column " + column_name, ErrorCodes::LOGICAL_ERROR);
return sample_block_with_columns_to_add.getByName(column_name).type;
auto elem = sample_block_with_columns_to_add.getByName(column_name);
if (or_null)
elem.type = makeNullable(elem.type);
return elem.type;
}
template <typename Maps>
void HashJoin::joinGetImpl(Block & block, const String & column_name, const Maps & maps_) const
void HashJoin::joinGetImpl(Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const
{
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::RightAny>(
block, {block.getByPosition(0).name}, {sample_block_with_columns_to_add.getByName(column_name)}, maps_);
block, {block.getByPosition(0).name}, block_with_columns_to_add, maps_);
}
// TODO: support composite key
// TODO: return multiple columns as named tuple
// TODO: return array of values when strictness == ASTTableJoin::Strictness::All
void HashJoin::joinGet(Block & block, const String & column_name) const
void HashJoin::joinGet(Block & block, const String & column_name, bool or_null) const
{
std::shared_lock lock(data->rwlock);
......@@ -1188,10 +1189,15 @@ void HashJoin::joinGet(Block & block, const String & column_name) const
checkTypeOfKey(block, right_table_keys);
auto elem = sample_block_with_columns_to_add.getByName(column_name);
if (or_null)
elem.type = makeNullable(elem.type);
elem.column = elem.type->createColumn();
if ((strictness == ASTTableJoin::Strictness::Any || strictness == ASTTableJoin::Strictness::RightAny) &&
kind == ASTTableJoin::Kind::Left)
{
joinGetImpl(block, column_name, std::get<MapsOne>(data->maps));
joinGetImpl(block, {elem}, std::get<MapsOne>(data->maps));
}
else
throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::LOGICAL_ERROR);
......
......@@ -161,10 +161,10 @@ public:
void joinBlock(Block & block, ExtraBlockPtr & not_processed) override;
/// Infer the return type for joinGet function
DataTypePtr joinGetReturnType(const String & column_name) const;
DataTypePtr joinGetReturnType(const String & column_name, bool or_null) const;
/// Used by joinGet function that turns StorageJoin into a dictionary
void joinGet(Block & block, const String & column_name) const;
void joinGet(Block & block, const String & column_name, bool or_null) const;
/** Keep "totals" (separate part of dataset, see WITH TOTALS) to use later.
*/
......@@ -382,7 +382,7 @@ private:
void joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed) const;
template <typename Maps>
void joinGetImpl(Block & block, const String & column_name, const Maps & maps) const;
void joinGetImpl(Block & block, const Block & block_with_columns_to_add, const Maps & maps_) const;
static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes);
};
......
DROP TABLE IF EXISTS join_test;
CREATE TABLE join_test (id UInt16, num UInt16) engine = Join(ANY, LEFT, id);
SELECT joinGetOrNull('join_test', 'num', 500);
DROP TABLE join_test;
CREATE TABLE join_test (id UInt16, num Nullable(UInt16)) engine = Join(ANY, LEFT, id);
SELECT joinGetOrNull('join_test', 'num', 500);
DROP TABLE join_test;
CREATE TABLE join_test (id UInt16, num Array(UInt16)) engine = Join(ANY, LEFT, id);
SELECT joinGetOrNull('join_test', 'num', 500); -- { serverError 43 }
DROP TABLE join_test;
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册