Commit 6fce028b authored by Nikolai Kochetov

Refactor ColumnsHashing.

Parent d2074985
This diff is collapsed.
#pragma once

#include <Columns/IColumn.h>
#include <Interpreters/AggregationCommon.h>

namespace DB
{
namespace ColumnsHashing
{
namespace columns_hashing_impl
{
/// Cache for the last seen (key, mapped) pair. Used for the consecutive keys optimization:
/// if the current key equals the previous one, the hash table lookup can be skipped.
template <typename Value, bool consecutive_keys_optimization_>
struct LastElementCache
{
    static constexpr bool consecutive_keys_optimization = consecutive_keys_optimization_;
    Value value;
    bool empty = true;
    bool found = false;

    bool check(const Value & value_) { return !empty && value == value_; }

    template <typename Key>
    bool check(const Key & key) { return !empty && value.first == key; }
};

/// Empty specialization: no state is kept when the optimization is disabled.
template <typename Data>
struct LastElementCache<Data, false>
{
    static constexpr bool consecutive_keys_optimization = false;
};

/// Result of an emplace into the hash table: the mapped value (plus its cached copy)
/// and a flag telling whether a new key was inserted.
template <typename Mapped>
class EmplaceResultImpl
{
    Mapped & value;
    Mapped & cached_value;
    bool inserted;

public:
    EmplaceResultImpl(Mapped & value, Mapped & cached_value, bool inserted)
            : value(value), cached_value(cached_value), inserted(inserted) {}

    bool isInserted() const { return inserted; }
    const auto & getMapped() const { return value; }

    /// Updates both the value stored in the hash table and the cached copy.
    void setMapped(const Mapped & mapped) { value = cached_value = mapped; }
};

/// Specialization for tables without a mapped value (sets): only the inserted flag is needed.
template <>
class EmplaceResultImpl<void>
{
    bool inserted;

public:
    explicit EmplaceResultImpl(bool inserted) : inserted(inserted) {}
    bool isInserted() const { return inserted; }
};

/// Result of a lookup: the mapped value and whether the key was found.
template <typename Mapped>
class FindResultImpl
{
    Mapped value;
    bool found;

public:
    FindResultImpl(Mapped value, bool found) : value(value), found(found) {}
    bool isFound() const { return found; }
    const Mapped & getMapped() const { return value; }
};

template <>
class FindResultImpl<void>
{
    bool found;

public:
    explicit FindResultImpl(bool found) : found(found) {}
    bool isFound() const { return found; }
};

/// Common base for hashing methods: implements emplaceKeyImpl / findKeyImpl on top of a hash
/// table, optionally going through the last-element cache first.
template <typename Value, typename Mapped, bool consecutive_keys_optimization>
struct HashMethodBase
{
    using EmplaceResult = EmplaceResultImpl<Mapped>;
    using FindResult = FindResultImpl<Mapped>;
    static constexpr bool has_mapped = !std::is_same<Mapped, void>::value;
    using Cache = LastElementCache<Value, consecutive_keys_optimization>;

protected:
    Cache cache;

    HashMethodBase()
    {
        if constexpr (has_mapped && consecutive_keys_optimization)
        {
            /// Init PairNoInit elements.
            cache.value.second = Mapped();
            using Key = decltype(cache.value.first);
            cache.value.first = Key();
        }
    }

    template <typename Data, typename Key>
    ALWAYS_INLINE EmplaceResult emplaceKeyImpl(Key key, Data & data, typename Data::iterator & it)
    {
        if constexpr (Cache::consecutive_keys_optimization)
        {
            /// Same key as on the previous row: reuse the cached mapped value, skip the lookup.
            if (cache.found && cache.check(key))
            {
                if constexpr (has_mapped)
                    return EmplaceResult(cache.value.second, cache.value.second, false);
                else
                    return EmplaceResult(false);
            }
        }

        bool inserted = false;
        data.emplace(key, it, inserted);
        Mapped * cached = &it->second;

        if constexpr (consecutive_keys_optimization)
        {
            cache.value = *it;
            cache.found = true;
            cache.empty = false;
            cached = &cache.value.second;
        }

        if constexpr (has_mapped)
            return EmplaceResult(it->second, *cached, inserted);
        else
            return EmplaceResult(inserted);
    }

    template <typename Data, typename Key>
    ALWAYS_INLINE FindResult findKeyImpl(Key key, Data & data)
    {
        if constexpr (Cache::consecutive_keys_optimization)
        {
            if (cache.check(key))
            {
                if constexpr (has_mapped)
                    return FindResult(cache.found ? cache.value.second : Mapped(), cache.found);
                else
                    return FindResult(cache.found);
            }
        }

        auto it = data.find(key);
        bool found = it != data.end();

        if constexpr (consecutive_keys_optimization)
        {
            cache.found = found;
            cache.empty = false;

            if (found)
                cache.value = *it;
            else
            {
                if constexpr (has_mapped)
                    cache.value.first = key;
                else
                    cache.value = key;
            }
        }

        if constexpr (has_mapped)
            return FindResult(found ? it->second : Mapped(), found);
        else
            return FindResult(found);
    }
};

/// PODArray-backed storage of mapped values; an empty shell when there is no mapped type.
template <typename T>
struct MappedCache : public PaddedPODArray<T> {};

template <>
struct MappedCache<void> {};
/// This class is designed to provide the functionality that is required for
/// supporting nullable keys in HashMethodKeysFixed. If there are
/// no nullable keys, this class is merely implemented as an empty shell.
template <typename Key, bool has_nullable_keys>
class BaseStateKeysFixed;

/// Case where nullable keys are supported.
template <typename Key>
class BaseStateKeysFixed<Key, true>
{
protected:
    void init(const ColumnRawPtrs & key_columns)
    {
        null_maps.reserve(key_columns.size());
        actual_columns.reserve(key_columns.size());

        for (const auto & col : key_columns)
        {
            if (col->isColumnNullable())
            {
                const auto & nullable_col = static_cast<const ColumnNullable &>(*col);
                actual_columns.push_back(&nullable_col.getNestedColumn());
                null_maps.push_back(&nullable_col.getNullMapColumn());
            }
            else
            {
                actual_columns.push_back(col);
                null_maps.push_back(nullptr);
            }
        }
    }

    /// Return the columns which actually contain the values of the keys.
    /// For a given key column, if it is nullable, we return its nested
    /// column. Otherwise we return the key column itself.
    inline const ColumnRawPtrs & getActualColumns() const
    {
        return actual_columns;
    }

    /// Create a bitmap that indicates whether, for a particular row,
    /// a key column bears a null value or not.
    KeysNullMap<Key> createBitmap(size_t row) const
    {
        KeysNullMap<Key> bitmap{};

        for (size_t k = 0; k < null_maps.size(); ++k)
        {
            if (null_maps[k] != nullptr)
            {
                const auto & null_map = static_cast<const ColumnUInt8 &>(*null_maps[k]).getData();
                if (null_map[row] == 1)
                {
                    size_t bucket = k / 8;
                    size_t offset = k % 8;
                    bitmap[bucket] |= UInt8(1) << offset;
                }
            }
        }

        return bitmap;
    }

private:
    ColumnRawPtrs actual_columns;
    ColumnRawPtrs null_maps;
};

/// Case where nullable keys are not supported.
template <typename Key>
class BaseStateKeysFixed<Key, false>
{
protected:
    void init(const ColumnRawPtrs & columns) { actual_columns = columns; }

    const ColumnRawPtrs & getActualColumns() const { return actual_columns; }

    KeysNullMap<Key> createBitmap(size_t) const
    {
        throw Exception{"Internal error: calling createBitmap() for non-nullable keys"
                        " is forbidden", ErrorCodes::LOGICAL_ERROR};
    }

private:
    ColumnRawPtrs actual_columns;
};

}
}
}
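For orientation, here is a minimal, self-contained analogue of the emplace-result pattern defined above. It uses std::unordered_map and hypothetical names (EmplaceResult, emplaceKey) rather than the actual ClickHouse hash tables and HashMethodBase; it only sketches the control flow that the Aggregator diff below adopts, where the caller inspects isInserted() instead of passing bool & inserted / bool & found out-parameters.

#include <cstdint>
#include <iostream>
#include <unordered_map>

/// Simplified stand-in for EmplaceResultImpl: bundles the mapped value with the inserted flag.
template <typename Mapped>
class EmplaceResult
{
    Mapped & value;
    bool inserted;

public:
    EmplaceResult(Mapped & value_, bool inserted_) : value(value_), inserted(inserted_) {}

    bool isInserted() const { return inserted; }
    Mapped & getMapped() const { return value; }
    void setMapped(const Mapped & mapped) { value = mapped; }
};

/// Simplified stand-in for HashMethodBase::emplaceKeyImpl, working on std::unordered_map.
template <typename Map>
EmplaceResult<typename Map::mapped_type> emplaceKey(Map & map, const typename Map::key_type & key)
{
    auto [it, inserted] = map.try_emplace(key);
    return {it->second, inserted};
}

int main()
{
    std::unordered_map<std::uint64_t, std::uint64_t> counts;

    /// The caller decides what to do on insert vs. on hit, which is the control flow
    /// the refactored executeImplCase below uses for aggregate states.
    for (std::uint64_t key : {1, 2, 2, 3, 3, 3})
    {
        auto result = emplaceKey(counts, key);
        if (result.isInserted())
            result.setMapped(1);        /// initialize the freshly inserted entry
        else
            ++result.getMapped();       /// update the existing entry
    }

    for (const auto & [k, v] : counts)
        std::cout << k << " -> " << v << "\n";
}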
@@ -609,20 +609,34 @@ void NO_INLINE Aggregator::executeImplCase(
     /// NOTE When editing this code, also pay attention to SpecializedAggregator.h.
     /// For all rows.
-    AggregateDataPtr value = nullptr;
     for (size_t i = 0; i < rows; ++i)
     {
-        bool inserted = false; /// Inserted a new key, or was this key already?
-        AggregateDataPtr * aggregate_data = nullptr;
+        AggregateDataPtr aggregate_data = nullptr;
         if constexpr (!no_more_keys) /// Insert.
-            aggregate_data = state.emplaceKey(method.data, i, inserted, *aggregates_pool);
+        {
+            auto emplace_result = state.emplaceKey(method.data, i, *aggregates_pool);
+            /// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
+            if (emplace_result.isInserted())
+            {
+                /// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
+                emplace_result.setMapped(nullptr);
+                aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
+                createAggregateStates(aggregate_data);
+                emplace_result.setMapped(aggregate_data);
+            }
+            else
+                aggregate_data = emplace_result.getMapped();
+        }
         else
         {
             /// Add only if the key already exists.
-            bool found = false;
-            aggregate_data = state.findKey(method.data, i, found, *aggregates_pool);
+            auto find_result = state.findKey(method.data, i, *aggregates_pool);
+            if (find_result.isFound())
+                aggregate_data = find_result.getMapped();
         }
         /// aggregate_date == nullptr means that the new key did not fit in the hash table because of no_more_keys.
@@ -631,20 +645,7 @@ void NO_INLINE Aggregator::executeImplCase(
         if (!aggregate_data && !overflow_row)
             continue;
-        /// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
-        if (inserted)
-        {
-            /// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
-            *aggregate_data = nullptr;
-            AggregateDataPtr place = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
-            createAggregateStates(place);
-            *aggregate_data = place;
-            state.cacheData(i, place);
-        }
-        value = aggregate_data ? *aggregate_data : overflow_row;
+        AggregateDataPtr value = aggregate_data ? aggregate_data : overflow_row;
         /// Add values to the aggregate functions.
         for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
@@ -1951,17 +1952,28 @@ void NO_INLINE Aggregator::mergeStreamsImplCase(
     size_t rows = block.rows();
     for (size_t i = 0; i < rows; ++i)
     {
-        typename Table::iterator it;
-        AggregateDataPtr * aggregate_data = nullptr;
-        bool inserted = false; /// Inserted a new key, or was this key already?
+        AggregateDataPtr aggregate_data = nullptr;
         if (!no_more_keys)
-            aggregate_data = state.emplaceKey(data, i, inserted, *aggregates_pool);
+        {
+            auto emplace_result = state.emplaceKey(data, i, *aggregates_pool);
+            if (emplace_result.isInserted())
+            {
+                emplace_result.setMapped(nullptr);
+                aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
+                createAggregateStates(aggregate_data);
+                emplace_result.setMapped(aggregate_data);
+            }
+            else
+                aggregate_data = emplace_result.getMapped();
+        }
         else
         {
-            bool found;
-            aggregate_data = state.findKey(data, i, found, *aggregates_pool);
+            auto find_result = state.findKey(data, i, *aggregates_pool);
+            if (find_result.isFound())
+                aggregate_data = find_result.getMapped();
         }
         /// aggregate_date == nullptr means that the new key did not fit in the hash table because of no_more_keys.
@@ -1970,19 +1982,7 @@ void NO_INLINE Aggregator::mergeStreamsImplCase(
         if (!aggregate_data && !overflow_row)
            continue;
-        /// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
-        if (inserted)
-        {
-            *aggregate_data = nullptr;
-            AggregateDataPtr place = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
-            createAggregateStates(place);
-            *aggregate_data = place;
-            state.cacheData(i, place);
-        }
-        AggregateDataPtr value = aggregate_data ? *aggregate_data : overflow_row;
+        AggregateDataPtr value = aggregate_data ? aggregate_data : overflow_row;
         /// Merge state of aggregate functions.
         for (size_t j = 0; j < params.aggregates_size; ++j)
@@ -158,7 +158,7 @@ struct AggregationMethodOneNumber
     AggregationMethodOneNumber(const Other & other) : data(other.data) {}
     /// To use one `Method` in different threads, use different `State`.
-    using State = ColumnsHashing::HashMethodOneNumber<Data, FieldType>;
+    using State = ColumnsHashing::HashMethodOneNumber<typename Data::value_type, Mapped, FieldType>;
     /// Use optimization for low cardinality.
     static const bool low_cardinality_optimization = false;
@@ -188,7 +188,7 @@ struct AggregationMethodString
     template <typename Other>
     AggregationMethodString(const Other & other) : data(other.data) {}
-    using State = ColumnsHashing::HashMethodString<Data>;
+    using State = ColumnsHashing::HashMethodString<typename Data::value_type, Mapped>;
     static const bool low_cardinality_optimization = false;
@@ -216,7 +216,7 @@ struct AggregationMethodFixedString
     template <typename Other>
     AggregationMethodFixedString(const Other & other) : data(other.data) {}
-    using State = ColumnsHashing::HashMethodFixedString<Data>;
+    using State = ColumnsHashing::HashMethodFixedString<typename Data::value_type, Mapped>;
     static const bool low_cardinality_optimization = false;
@@ -246,7 +246,7 @@ struct AggregationMethodSingleLowCardinalityColumn : public SingleColumnMethod
     template <typename Other>
     explicit AggregationMethodSingleLowCardinalityColumn(const Other & other) : Base(other) {}
-    using State = ColumnsHashing::HashMethodSingleLowCardinalityColumn<BaseState, true>;
+    using State = ColumnsHashing::HashMethodSingleLowCardinalityColumn<BaseState, Mapped, true>;
     static const bool low_cardinality_optimization = true;
@@ -277,7 +277,7 @@ struct AggregationMethodKeysFixed
     template <typename Other>
     AggregationMethodKeysFixed(const Other & other) : data(other.data) {}
-    using State = ColumnsHashing::HashMethodKeysFixed<Data, has_nullable_keys, has_low_cardinality>;
+    using State = ColumnsHashing::HashMethodKeysFixed<typename Data::value_type, Key, Mapped, has_nullable_keys, has_low_cardinality>;
     static const bool low_cardinality_optimization = false;
@@ -355,7 +355,7 @@ struct AggregationMethodSerialized
     template <typename Other>
     AggregationMethodSerialized(const Other & other) : data(other.data) {}
-    using State = ColumnsHashing::HashMethodSerialized<Data>;
+    using State = ColumnsHashing::HashMethodSerialized<typename Data::value_type, Mapped>;
     static const bool low_cardinality_optimization = false;
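Finally, a rough, self-contained illustration of the consecutive keys optimization carried by LastElementCache in the header above (hypothetical CachedEmplacer over std::unordered_map, not the ClickHouse implementation): remembering the last emplaced entry lets runs of identical keys bypass the hash table lookup entirely, which is the effect HashMethodBase::emplaceKeyImpl aims for when consecutive_keys_optimization is enabled.

#include <cstdint>
#include <iostream>
#include <unordered_map>

template <typename Map>
class CachedEmplacer
{
public:
    using Key = typename Map::key_type;
    using Mapped = typename Map::mapped_type;

    explicit CachedEmplacer(Map & map_) : map(map_) {}

    /// Returns a reference to the mapped value for `key`, consulting the cache first.
    Mapped & emplace(const Key & key)
    {
        if (cached_mapped && cached_key == key)
            return *cached_mapped;                  /// cache hit: no hash table lookup

        auto [it, inserted] = map.try_emplace(key); /// cache miss: go to the table
        (void)inserted;
        cached_key = key;
        cached_mapped = &it->second;                /// references into unordered_map stay valid
        return it->second;
    }

private:
    Map & map;
    Key cached_key{};
    Mapped * cached_mapped = nullptr;
};

int main()
{
    std::unordered_map<std::uint64_t, std::uint64_t> counts;
    CachedEmplacer<decltype(counts)> emplacer(counts);

    /// Clustered keys (e.g. pre-sorted input): most emplace calls hit the cache.
    for (std::uint64_t key : {1, 1, 1, 2, 2, 3, 3, 3, 3})
        ++emplacer.emplace(key);

    for (const auto & [k, v] : counts)
        std::cout << k << " -> " << v << "\n";
}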