提交 3c9cf7d4 编写于 作者: A Alexey Milovidov

Merge branch 'master' of github.com:yandex/ClickHouse

...@@ -120,6 +120,11 @@ public: ...@@ -120,6 +120,11 @@ public:
nested_func->insertResultInto(place, to); nested_func->insertResultInto(place, to);
} }
/// Whether states of this function allocate memory in the Arena: answered by the nested function.
bool allocatesMemoryInArena() const override { return nested_func->allocatesMemoryInArena(); }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena)
{ {
static_cast<const AggregateFunctionArray &>(*that).add(place, columns, row_num, arena); static_cast<const AggregateFunctionArray &>(*that).add(place, columns, row_num, arena);
......
...@@ -103,6 +103,11 @@ public: ...@@ -103,6 +103,11 @@ public:
nested_func->insertResultInto(place, to); nested_func->insertResultInto(place, to);
} }
/// Whether states of this function allocate memory in the Arena: answered by the nested function.
bool allocatesMemoryInArena() const override { return nested_func->allocatesMemoryInArena(); }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena)
{ {
static_cast<const AggregateFunctionIf &>(*that).add(place, columns, row_num, arena); static_cast<const AggregateFunctionIf &>(*that).add(place, columns, row_num, arena);
......
...@@ -104,6 +104,11 @@ public: ...@@ -104,6 +104,11 @@ public:
nested_func->insertResultInto(place, to); nested_func->insertResultInto(place, to);
} }
/// Whether states of this function allocate memory in the Arena: answered by the nested function.
bool allocatesMemoryInArena() const override { return nested_func->allocatesMemoryInArena(); }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena)
{ {
static_cast<const AggregateFunctionMerge &>(*that).add(place, columns, row_num, arena); static_cast<const AggregateFunctionMerge &>(*that).add(place, columns, row_num, arena);
......
...@@ -147,6 +147,11 @@ public: ...@@ -147,6 +147,11 @@ public:
to_concrete.insertDefault(); to_concrete.insertDefault();
} }
} }
/// Whether states of this function allocate memory in the Arena: answered by the nested function.
bool allocatesMemoryInArena() const override { return nested_function->allocatesMemoryInArena(); }
}; };
...@@ -255,6 +260,11 @@ public: ...@@ -255,6 +260,11 @@ public:
nested_function->add(nestedPlace(place), nested_columns, row_num, arena); nested_function->add(nestedPlace(place), nested_columns, row_num, arena);
} }
/// Whether states of this function allocate memory in the Arena: answered by the nested function.
bool allocatesMemoryInArena() const override { return nested_function->allocatesMemoryInArena(); }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, static void addFree(const IAggregateFunction * that, AggregateDataPtr place,
const IColumn ** columns, size_t row_num, Arena * arena) const IColumn ** columns, size_t row_num, Arena * arena)
{ {
......
...@@ -100,6 +100,11 @@ public: ...@@ -100,6 +100,11 @@ public:
/// Аггрегатная функция или состояние аггрегатной функции. /// Аггрегатная функция или состояние аггрегатной функции.
bool isState() const override { return true; } bool isState() const override { return true; }
/// Whether states of this function allocate memory in the Arena: answered by the nested function.
bool allocatesMemoryInArena() const override { return nested_func->allocatesMemoryInArena(); }
AggregateFunctionPtr getNestedFunction() const { return nested_func_owner; } AggregateFunctionPtr getNestedFunction() const { return nested_func_owner; }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena)
......
...@@ -160,11 +160,21 @@ public: ...@@ -160,11 +160,21 @@ public:
insertMergeFrom(src, n); insertMergeFrom(src, n);
} }
/// Inserts a row built from the state at `place`:
/// first insertDefault() is called, then `place` is merged into the last row.
void insertFrom(ConstAggregateDataPtr place)
{
    insertDefault();
    insertMergeFrom(place);
}
/// Merge state at last row with specified state in another column. /// Merge state at last row with specified state in another column.
/// Merges the state at `place` into the state stored in the last row of this column.
/// The arena returned by createOrGetArena() is handed to the aggregate function's merge.
void insertMergeFrom(ConstAggregateDataPtr place)
{
    Arena & arena = createOrGetArena();
    func->merge(getData().back(), place, &arena);
}
void insertMergeFrom(const IColumn & src, size_t n) void insertMergeFrom(const IColumn & src, size_t n)
{ {
Arena & arena = createOrGetArena(); insertMergeFrom(static_cast<const ColumnAggregateFunction &>(src).getData()[n]);
func->merge(getData().back(), static_cast<const ColumnAggregateFunction &>(src).getData()[n], &arena);
} }
Arena & createOrGetArena() Arena & createOrGetArena()
...@@ -206,10 +216,7 @@ public: ...@@ -206,10 +216,7 @@ public:
throw Exception("Method deserializeAndInsertFromArena is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); throw Exception("Method deserializeAndInsertFromArena is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
} }
void updateHashWithValue(size_t n, SipHash & hash) const override void updateHashWithValue(size_t n, SipHash & hash) const override;
{
throw Exception("Method updateHashWithValue is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
size_t byteSize() const override; size_t byteSize() const override;
......
...@@ -26,6 +26,22 @@ public: ...@@ -26,6 +26,22 @@ public:
bool isConst() const override { return true; } bool isConst() const override { return true; }
virtual ColumnPtr convertToFullColumn() const = 0; virtual ColumnPtr convertToFullColumn() const = 0;
ColumnPtr convertToFullColumnIfConst() const override { return convertToFullColumn(); } ColumnPtr convertToFullColumnIfConst() const override { return convertToFullColumn(); }
/** Splits the constant column into `num_columns` constant columns.
  * Every selector entry names the destination column of the corresponding row,
  * so only the number of rows per destination has to be counted:
  * each result is this column resized (via cloneResized) to that count.
  * Throws if the selector length differs from the column size.
  */
Columns scatter(ColumnIndex num_columns, const Selector & selector) const override
{
    const bool sizes_match = selector.size() == size();
    if (!sizes_match)
        throw Exception("Size of selector doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

    /// Number of rows routed to each destination column.
    std::vector<size_t> rows_per_column(num_columns);
    for (auto destination : selector)
        rows_per_column[destination] += 1;

    Columns scattered(num_columns);
    size_t column_idx = 0;
    for (const size_t rows : rows_per_column)
    {
        scattered[column_idx] = cloneResized(rows);
        ++column_idx;
    }

    return scattered;
}
}; };
...@@ -158,22 +174,6 @@ public: ...@@ -158,22 +174,6 @@ public:
return std::make_shared<Derived>(replicated_size, data, data_type); return std::make_shared<Derived>(replicated_size, data, data_type);
} }
Columns scatter(ColumnIndex num_columns, const Selector & selector) const override
{
if (s != selector.size())
throw Exception("Size of selector doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);
std::vector<size_t> counts(num_columns);
for (auto idx : selector)
++counts[idx];
Columns res(num_columns);
for (size_t i = 0; i < num_columns; ++i)
res[i] = cloneResized(counts[i]);
return res;
}
size_t byteSize() const override { return sizeof(data) + sizeof(s); } size_t byteSize() const override { return sizeof(data) + sizeof(s); }
size_t allocatedSize() const override { return byteSize(); } size_t allocatedSize() const override { return byteSize(); }
......
#pragma once
#include <DB/Columns/ColumnConst.h>
#include <DB/DataTypes/DataTypeAggregateFunction.h>
namespace DB
{
class ColumnConstAggregateFunction : public IColumnConst
{
public:
ColumnConstAggregateFunction(size_t size, const Field & value_, const DataTypePtr & data_type_)
: data_type(data_type_), value(value_), s(size)
{
}
String getName() const override
{
return "ColumnConstAggregateFunction";
}
bool isConst() const override
{
return true;
}
ColumnPtr convertToFullColumnIfConst() const override
{
return convertToFullColumn();
}
ColumnPtr convertToFullColumn() const override
{
auto res = std::make_shared<ColumnAggregateFunction>(getAggregateFunction());
for (size_t i = 0; i < s; ++i)
res->insert(value);
return res;
}
ColumnPtr cloneResized(size_t new_size) const override
{
return std::make_shared<ColumnConstAggregateFunction>(new_size, value, data_type);
}
size_t size() const override
{
return s;
}
Field operator[](size_t n) const override
{
/// NOTE: there are no out of bounds check (like in ColumnConstBase)
return value;
}
void get(size_t n, Field & res) const override
{
res = value;
}
StringRef getDataAt(size_t n) const override
{
return value.get<const String &>();
}
void insert(const Field & x) override
{
/// NOTE: Cannot check source function of x
if (value != x)
throw Exception("Cannot insert different element into constant column " + getName(),
ErrorCodes::CANNOT_INSERT_ELEMENT_INTO_CONSTANT_COLUMN);
++s;
}
void insertRangeFrom(const IColumn & src, size_t start, size_t length) override
{
if (!equalsFuncAndValue(src))
throw Exception("Cannot insert different element into constant column " + getName(),
ErrorCodes::CANNOT_INSERT_ELEMENT_INTO_CONSTANT_COLUMN);
s += length;
}
void insertData(const char * pos, size_t length) override
{
throw Exception("Method insertData is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
void insertDefault() override
{
++s;
}
void popBack(size_t n) override
{
s -= n;
}
StringRef serializeValueIntoArena(size_t n, Arena & arena, char const *& begin) const override
{
throw Exception("Method serializeValueIntoArena is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
const char * deserializeAndInsertFromArena(const char * pos) override
{
throw Exception("Method deserializeAndInsertFromArena is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
void updateHashWithValue(size_t n, SipHash & hash) const override
{
throw Exception("Method updateHashWithValue is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint) const override
{
if (s != filt.size())
throw Exception("Size of filter doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);
return std::make_shared<ColumnConstAggregateFunction>(countBytesInFilter(filt), value, data_type);
}
ColumnPtr permute(const Permutation & perm, size_t limit) const override
{
if (limit == 0)
limit = s;
else
limit = std::min(s, limit);
if (perm.size() < limit)
throw Exception("Size of permutation is less than required.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);
return std::make_shared<ColumnConstAggregateFunction>(limit, value, data_type);
}
int compareAt(size_t n, size_t m, const IColumn & rhs_, int nan_direction_hint) const override
{
return 0;
}
void getPermutation(bool reverse, size_t limit, Permutation & res) const override
{
res.resize(s);
for (size_t i = 0; i < s; ++i)
res[i] = i;
}
ColumnPtr replicate(const Offsets_t & offsets) const override
{
if (s != offsets.size())
throw Exception("Size of offsets doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);
size_t replicated_size = 0 == s ? 0 : offsets.back();
return std::make_shared<ColumnConstAggregateFunction>(replicated_size, value, data_type);
}
void getExtremes(Field & min, Field & max) const override
{
min = value;
max = value;
}
size_t byteSize() const override
{
return sizeof(value) + sizeof(s);
}
size_t allocatedSize() const override
{
return byteSize();
}
private:
DataTypePtr data_type;
Field value;
size_t s;
AggregateFunctionPtr getAggregateFunction() const
{
return typeid_cast<const DataTypeAggregateFunction &>(*data_type).getFunction();
}
bool equalsFuncAndValue(const IColumn & rhs) const
{
auto rhs_const = dynamic_cast<const ColumnConstAggregateFunction *>(&rhs);
return rhs_const && value == rhs_const->value && data_type->equals(*rhs_const->data_type);
}
};
}
此差异已折叠。
...@@ -462,4 +462,13 @@ void swap(PODArray<T, INITIAL_SIZE, TAllocator, pad_right_> & lhs, PODArray<T, I ...@@ -462,4 +462,13 @@ void swap(PODArray<T, INITIAL_SIZE, TAllocator, pad_right_> & lhs, PODArray<T, I
template <typename T, size_t INITIAL_SIZE = 4096, typename TAllocator = Allocator<false>> template <typename T, size_t INITIAL_SIZE = 4096, typename TAllocator = Allocator<false>>
using PaddedPODArray = PODArray<T, INITIAL_SIZE, TAllocator, 15>; using PaddedPODArray = PODArray<T, INITIAL_SIZE, TAllocator, 15>;
/** Rounds `value` up to the nearest multiple of `dividend`
  * (despite its name, `dividend` is the divisor / alignment step).
  * A value that is already a multiple is returned unchanged.
  * Computed through the remainder rather than `value + dividend - 1`,
  * so the intermediate result cannot wrap around for values close to the maximum of size_t.
  * `dividend` must be non-zero.
  */
inline constexpr size_t integerRound(size_t value, size_t dividend)
{
    return (value % dividend == 0) ? value : value + (dividend - value % dividend);
}
/** PODArray with INITIAL_SIZE of 0 that uses AllocatorWithStackMemory on top of Allocator<false>.
  * The amount passed to the allocator is `stack_size_in_bytes` rounded up
  * to a multiple of sizeof(T) by integerRound.
  * NOTE(review): presumably the allocator serves small arrays from an inline buffer of that size
  * before falling back to the heap - confirm against AllocatorWithStackMemory.
  */
template <typename T, size_t stack_size_in_bytes>
using PODArrayWithStackMemory = PODArray<T, 0, AllocatorWithStackMemory<Allocator<false>, integerRound(stack_size_in_bytes, sizeof(T))>>;
} }
...@@ -32,6 +32,7 @@ public: ...@@ -32,6 +32,7 @@ public:
} }
std::string getFunctionName() const { return function->getName(); } std::string getFunctionName() const { return function->getName(); }
/// Returns the aggregate function held by this data type (a copy of the shared pointer).
AggregateFunctionPtr getFunction() const
{
    return function;
}
std::string getName() const override; std::string getName() const override;
...@@ -62,10 +63,7 @@ public: ...@@ -62,10 +63,7 @@ public:
ColumnPtr createColumn() const override; ColumnPtr createColumn() const override;
ColumnPtr createConstColumn(size_t size, const Field & field) const override; ColumnPtr createConstColumn(size_t size, const Field & field) const override;
Field getDefault() const override Field getDefault() const override;
{
throw Exception("There is no default value for AggregateFunction data type", ErrorCodes::THERE_IS_NO_DEFAULT_VALUE);
}
}; };
......
...@@ -141,6 +141,12 @@ public: ...@@ -141,6 +141,12 @@ public:
throw Exception("getSizeOfField() method is not implemented for data type " + getName(), ErrorCodes::NOT_IMPLEMENTED); throw Exception("getSizeOfField() method is not implemented for data type " + getName(), ErrorCodes::NOT_IMPLEMENTED);
} }
/// Two data types are considered the same type when their getName() strings are equal.
inline bool equals(const IDataType & rhs) const
{
    const auto own_name = getName();
    return own_name == rhs.getName();
}
virtual ~IDataType() {} virtual ~IDataType() {}
}; };
......
...@@ -229,10 +229,11 @@ private: ...@@ -229,10 +229,11 @@ private:
Attribute & getAttribute(const std::string & attribute_name) const; Attribute & getAttribute(const std::string & attribute_name) const;
struct FindResult { struct FindResult
{
const size_t cell_idx;
const bool valid; const bool valid;
const bool outdated; const bool outdated;
const size_t cell_idx;
}; };
FindResult findCellIdx(const Key & id, const CellMetadata::time_point_t now) const; FindResult findCellIdx(const Key & id, const CellMetadata::time_point_t now) const;
...@@ -244,13 +245,13 @@ private: ...@@ -244,13 +245,13 @@ private:
mutable Poco::RWLock rw_lock; mutable Poco::RWLock rw_lock;
// Actual size will be increased to match power of 2 /// Actual size will be increased to match power of 2
const std::size_t size; const std::size_t size;
// all bits to 1 mask (size - 1) (0b1000 - 1 = 0b111) /// all bits to 1 mask (size - 1) (0b1000 - 1 = 0b111)
const std::size_t size_overlap_mask; const std::size_t size_overlap_mask;
// Max tries to find cell, overlaped with mask: if size = 16 and start_cell=10: will try cells: 10,11,12,13,14,15,0,1,2,3 /// Max tries to find cell, overlaped with mask: if size = 16 and start_cell=10: will try cells: 10,11,12,13,14,15,0,1,2,3
static constexpr std::size_t max_collision_length = 10; static constexpr std::size_t max_collision_length = 10;
const UInt64 zero_cell_idx{getCellIdx(0)}; const UInt64 zero_cell_idx{getCellIdx(0)};
......
...@@ -257,6 +257,20 @@ private: ...@@ -257,6 +257,20 @@ private:
static StringRef copyIntoArena(StringRef src, Arena & arena); static StringRef copyIntoArena(StringRef src, Arena & arena);
StringRef copyKey(const StringRef key) const; StringRef copyKey(const StringRef key) const;
/// Result of findCellIdx. Brace-initialized in the order {cell_idx, valid, outdated}.
struct FindResult
{
/// Index into `cells`: the matching cell, or the replacement candidate when the key was not found.
const size_t cell_idx;
/// True when a cell with the same hash and key was found and it has not expired.
const bool valid;
/// True when a cell with the same hash and key was found but it has already expired.
const bool outdated;
};
/// Looks up `key` whose hash has already been computed by the caller.
FindResult findCellIdx(const StringRef & key, const CellMetadata::time_point_t now, const size_t hash) const;

/// Convenience overload: hashes `key` with StringRefHash and forwards to the overload taking a hash.
FindResult findCellIdx(const StringRef & key, const CellMetadata::time_point_t now) const
{
    return findCellIdx(key, now, StringRefHash{}(key));
}
const std::string name; const std::string name;
const DictionaryStructure dict_struct; const DictionaryStructure dict_struct;
const DictionarySourcePtr source_ptr; const DictionarySourcePtr source_ptr;
...@@ -264,7 +278,16 @@ private: ...@@ -264,7 +278,16 @@ private:
const std::string key_description{dict_struct.getKeyDescription()}; const std::string key_description{dict_struct.getKeyDescription()};
mutable Poco::RWLock rw_lock; mutable Poco::RWLock rw_lock;
/// Actual size will be increased to match power of 2
const std::size_t size; const std::size_t size;
/// Bit mask with all low bits set, equal to (size - 1); e.g. for size = 0b1000 the mask is 0b111
const std::size_t size_overlap_mask;
/// Maximum number of cells probed for one key, wrapping around with the mask: if size = 16 and start_cell = 10, cells 10,11,12,13,14,15,0,1,2,3 are tried
static constexpr std::size_t max_collision_length = 10;
const UInt64 zero_cell_idx{getCellIdx(StringRef{})}; const UInt64 zero_cell_idx{getCellIdx(StringRef{})};
std::map<std::string, std::size_t> attribute_index_by_name; std::map<std::string, std::size_t> attribute_index_by_name;
mutable std::vector<Attribute> attributes; mutable std::vector<Attribute> attributes;
......
...@@ -1498,19 +1498,11 @@ private: ...@@ -1498,19 +1498,11 @@ private:
}; };
} }
/// Only trivial NULL -> NULL case WrapperType createIdentityWrapper(const DataTypePtr &)
WrapperType createNullWrapper(const DataTypePtr & from_type, const DataTypeNull * to_type)
{ {
if (!typeid_cast<const DataTypeNull *>(from_type.get()))
throw Exception("Conversion from " + from_type->getName() + " to " + to_type->getName() + " is not supported",
ErrorCodes::CANNOT_CONVERT_TYPE);
return [] (Block & block, const ColumnNumbers & arguments, const size_t result) return [] (Block & block, const ColumnNumbers & arguments, const size_t result)
{ {
// just copy pointer to Null column block.safeGetByPosition(result).column = block.safeGetByPosition(arguments.front()).column;
ColumnWithTypeAndName & res_col = block.safeGetByPosition(result);
const ColumnWithTypeAndName & src_col = block.safeGetByPosition(arguments.front());
res_col.column = src_col.column;
}; };
} }
...@@ -1602,7 +1594,9 @@ private: ...@@ -1602,7 +1594,9 @@ private:
WrapperType prepareImpl(const DataTypePtr & from_type, const IDataType * const to_type) WrapperType prepareImpl(const DataTypePtr & from_type, const IDataType * const to_type)
{ {
if (const auto to_actual_type = typeid_cast<const DataTypeUInt8 *>(to_type)) if (from_type->equals(*to_type))
return createIdentityWrapper(from_type);
else if (const auto to_actual_type = typeid_cast<const DataTypeUInt8 *>(to_type))
return createWrapper(from_type, to_actual_type); return createWrapper(from_type, to_actual_type);
else if (const auto to_actual_type = typeid_cast<const DataTypeUInt16 *>(to_type)) else if (const auto to_actual_type = typeid_cast<const DataTypeUInt16 *>(to_type))
return createWrapper(from_type, to_actual_type); return createWrapper(from_type, to_actual_type);
...@@ -1638,8 +1632,6 @@ private: ...@@ -1638,8 +1632,6 @@ private:
return createEnumWrapper(from_type, type_enum); return createEnumWrapper(from_type, type_enum);
else if (const auto type_enum = typeid_cast<const DataTypeEnum16 *>(to_type)) else if (const auto type_enum = typeid_cast<const DataTypeEnum16 *>(to_type))
return createEnumWrapper(from_type, type_enum); return createEnumWrapper(from_type, type_enum);
else if (const auto type_null = typeid_cast<const DataTypeNull *>(to_type))
return createNullWrapper(from_type, type_null);
/// It's possible to use ConvertImplGenericFromString to convert from String to AggregateFunction, /// It's possible to use ConvertImplGenericFromString to convert from String to AggregateFunction,
/// but it is disabled because deserializing aggregate functions state might be unsafe. /// but it is disabled because deserializing aggregate functions state might be unsafe.
...@@ -1691,7 +1683,7 @@ private: ...@@ -1691,7 +1683,7 @@ private:
else if (const auto type = typeid_cast<const DataTypeEnum16 *>(to_type)) else if (const auto type = typeid_cast<const DataTypeEnum16 *>(to_type))
monotonicity_for_range = monotonicityForType(type); monotonicity_for_range = monotonicityForType(type);
} }
/// other types like FixedString, Array and Tuple have no monotonicity defined /// other types like Null, FixedString, Array and Tuple have no monotonicity defined
} }
public: public:
......
...@@ -15,6 +15,6 @@ class IDataType; ...@@ -15,6 +15,6 @@ class IDataType;
* Проверяет совместимость типов, проверяет попадание значений в диапазон допустимых значений типа, делает преобразование типа. * Проверяет совместимость типов, проверяет попадание значений в диапазон допустимых значений типа, делает преобразование типа.
* Если значение не попадает в диапазон - возвращает Null. * Если значение не попадает в диапазон - возвращает Null.
*/ */
Field convertFieldToType(const Field & src, const IDataType & type); Field convertFieldToType(const Field & from_value, const IDataType & to_type, const IDataType * from_type_hint = nullptr);
} }
...@@ -9,12 +9,13 @@ namespace DB ...@@ -9,12 +9,13 @@ namespace DB
class IAST; class IAST;
class Context; class Context;
class IDataType;
/** Evaluate constant expression. /** Evaluate constant expression and its type.
* Used in rare cases - for elements of set for IN, for data to INSERT. * Used in rare cases - for elements of set for IN, for data to INSERT.
* Quite suboptimal. * Quite suboptimal.
*/ */
Field evaluateConstantExpression(std::shared_ptr<IAST> & node, const Context & context); std::pair<Field, std::shared_ptr<IDataType>> evaluateConstantExpression(std::shared_ptr<IAST> & node, const Context & context);
/** Evaluate constant expression /** Evaluate constant expression
......
#include <DB/AggregateFunctions/AggregateFunctionState.h> #include <DB/AggregateFunctions/AggregateFunctionState.h>
#include <DB/Columns/ColumnAggregateFunction.h> #include <DB/Columns/ColumnAggregateFunction.h>
#include <DB/Common/SipHash.h>
namespace DB namespace DB
{ {
...@@ -142,6 +142,16 @@ ColumnPtr ColumnAggregateFunction::permute(const Permutation & perm, size_t limi ...@@ -142,6 +142,16 @@ ColumnPtr ColumnAggregateFunction::permute(const Permutation & perm, size_t limi
return res; return res;
} }
/// Is required to support operations with Set
void ColumnAggregateFunction::updateHashWithValue(size_t n, SipHash & hash) const
{
String buf;
{
WriteBufferFromString wbuf(buf);
func->serialize(getData()[n], wbuf);
}
hash.update(buf.c_str(), buf.size());
}
/// NOTE: Highly overestimates size of a column if it was produced in AggregatingBlockInputStream (it contains size of other columns) /// NOTE: Highly overestimates size of a column if it was produced in AggregatingBlockInputStream (it contains size of other columns)
size_t ColumnAggregateFunction::byteSize() const size_t ColumnAggregateFunction::byteSize() const
......
...@@ -109,7 +109,8 @@ bool ValuesRowInputStream::read(Block & block) ...@@ -109,7 +109,8 @@ bool ValuesRowInputStream::read(Block & block)
istr.position() = const_cast<char *>(max_parsed_pos); istr.position() = const_cast<char *>(max_parsed_pos);
Field value = convertFieldToType(evaluateConstantExpression(ast, context), type); std::pair<Field, DataTypePtr> value_raw = evaluateConstantExpression(ast, context);
Field value = convertFieldToType(value_raw.first, type, value_raw.second.get());
if (value.isNull()) if (value.isNull())
{ {
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#include <DB/IO/WriteHelpers.h> #include <DB/IO/WriteHelpers.h>
#include <DB/IO/ReadHelpers.h> #include <DB/IO/ReadHelpers.h>
#include <DB/Columns/ColumnConst.h>
#include <DB/Columns/ColumnAggregateFunction.h> #include <DB/Columns/ColumnAggregateFunction.h>
#include <DB/Columns/ColumnConstAggregateFunction.h>
#include <DB/DataTypes/DataTypeAggregateFunction.h> #include <DB/DataTypes/DataTypeAggregateFunction.h>
...@@ -29,8 +29,8 @@ std::string DataTypeAggregateFunction::getName() const ...@@ -29,8 +29,8 @@ std::string DataTypeAggregateFunction::getName() const
stream << ")"; stream << ")";
} }
for (DataTypes::const_iterator it = argument_types.begin(); it != argument_types.end(); ++it) for (const auto & argument_type: argument_types)
stream << ", " << (*it)->getName(); stream << ", " << argument_type->getName();
stream << ")"; stream << ")";
return stream.str(); return stream.str();
...@@ -236,7 +236,33 @@ ColumnPtr DataTypeAggregateFunction::createColumn() const ...@@ -236,7 +236,33 @@ ColumnPtr DataTypeAggregateFunction::createColumn() const
ColumnPtr DataTypeAggregateFunction::createConstColumn(size_t size, const Field & field) const ColumnPtr DataTypeAggregateFunction::createConstColumn(size_t size, const Field & field) const
{ {
throw Exception("Const column with aggregate function is not supported", ErrorCodes::NOT_IMPLEMENTED); return std::make_shared<ColumnConstAggregateFunction>(size, field, clone());
}
/** The default value is a freshly created state of the function, serialized into a String Field.
  * The state lives in a temporary buffer of function->sizeOfData() bytes
  * and is destroyed before returning, also when serialization throws.
  */
Field DataTypeAggregateFunction::getDefault() const
{
    Field serialized_default = String();

    PODArrayWithStackMemory<char, 16> state_memory(function->sizeOfData());
    AggregateDataPtr state = state_memory.data();

    function->create(state);

    try
    {
        /// Writes directly into the String stored inside the Field.
        WriteBufferFromString out(serialized_default.get<String &>());
        function->serialize(state, out);
    }
    catch (...)
    {
        function->destroy(state);
        throw;
    }

    function->destroy(state);

    return serialized_default;
}
......
...@@ -25,7 +25,7 @@ namespace ErrorCodes ...@@ -25,7 +25,7 @@ namespace ErrorCodes
inline UInt64 CacheDictionary::getCellIdx(const Key id) const inline UInt64 CacheDictionary::getCellIdx(const Key id) const
{ {
const auto hash = intHash64(id); const auto hash = intHash64(id);
const auto idx = hash & (size - 1); const auto idx = hash & size_overlap_mask;
return idx; return idx;
} }
...@@ -175,7 +175,7 @@ void CacheDictionary::getString( ...@@ -175,7 +175,7 @@ void CacheDictionary::getString(
} }
/// returns 'cell is valid' flag, 'cell is outdated' flag, cell_idx /// returns cell_idx (always valid for replacing), 'cell is valid' flag, 'cell is outdated' flag
/// true false found and valid /// true false found and valid
/// false true not found (something outdated, maybe our cell) /// false true not found (something outdated, maybe our cell)
/// false false not found (other id stored with valid data) /// false false not found (other id stored with valid data)
...@@ -206,13 +206,13 @@ CacheDictionary::FindResult CacheDictionary::findCellIdx(const Key & id, const C ...@@ -206,13 +206,13 @@ CacheDictionary::FindResult CacheDictionary::findCellIdx(const Key & id, const C
if (cell.expiresAt() < now) if (cell.expiresAt() < now)
{ {
return {false, true, cell_idx}; return {cell_idx, false, true};
} }
return {true, false, cell_idx}; return {cell_idx, true, false};
} }
return {false, false, oldest_id}; return {oldest_id, false, false};
} }
void CacheDictionary::has(const PaddedPODArray<Key> & ids, PaddedPODArray<UInt8> & out) const void CacheDictionary::has(const PaddedPODArray<Key> & ids, PaddedPODArray<UInt8> & out) const
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
#include <DB/Common/ProfilingScopedRWLock.h> #include <DB/Common/ProfilingScopedRWLock.h>
#include <ext/range.hpp> #include <ext/range.hpp>
namespace DB namespace DB
{ {
...@@ -20,7 +19,7 @@ namespace ErrorCodes ...@@ -20,7 +19,7 @@ namespace ErrorCodes
inline UInt64 ComplexKeyCacheDictionary::getCellIdx(const StringRef key) const inline UInt64 ComplexKeyCacheDictionary::getCellIdx(const StringRef key) const
{ {
const auto hash = StringRefHash{}(key); const auto hash = StringRefHash{}(key);
const auto idx = hash & (size - 1); const auto idx = hash & size_overlap_mask;
return idx; return idx;
} }
...@@ -29,7 +28,9 @@ ComplexKeyCacheDictionary::ComplexKeyCacheDictionary(const std::string & name, c ...@@ -29,7 +28,9 @@ ComplexKeyCacheDictionary::ComplexKeyCacheDictionary(const std::string & name, c
DictionarySourcePtr source_ptr, const DictionaryLifetime dict_lifetime, DictionarySourcePtr source_ptr, const DictionaryLifetime dict_lifetime,
const size_t size) const size_t size)
: name{name}, dict_struct(dict_struct), source_ptr{std::move(source_ptr)}, dict_lifetime(dict_lifetime), : name{name}, dict_struct(dict_struct), source_ptr{std::move(source_ptr)}, dict_lifetime(dict_lifetime),
size{roundUpToPowerOfTwoOrZero(size)}, rnd_engine{randomSeed()} size{roundUpToPowerOfTwoOrZero(std::max(size, size_t(max_collision_length)))},
size_overlap_mask{this->size - 1},
rnd_engine{randomSeed()}
{ {
if (!this->source_ptr->supportsSelectiveLoad()) if (!this->source_ptr->supportsSelectiveLoad())
throw Exception{ throw Exception{
...@@ -174,6 +175,52 @@ void ComplexKeyCacheDictionary::getString( ...@@ -174,6 +175,52 @@ void ComplexKeyCacheDictionary::getString(
getItemsString(attribute, key_columns, out, [&] (const size_t) { return StringRef{def}; }); getItemsString(attribute, key_columns, out, [&] (const size_t) { return StringRef{def}; });
} }
/** Looks for `key` (with precomputed `hash`) in at most max_collision_length consecutive cells,
  * starting at the cell selected by the hash and wrapping around with size_overlap_mask.
  *
  * Returns {cell_idx, valid, outdated}:
  *   valid  outdated
  *   true   false     the key is found and the cell has not expired; cell_idx is its index
  *   false  true      the key is found but the cell has expired; cell_idx is its index
  *   false  false     the key is not found; cell_idx is the cell to replace:
  *                    the first already expired cell among those probed,
  *                    or, if none has expired, the one that expires earliest
  *   true   true      impossible
  *
  * todo: split this func to two: find_for_get and find_for_set
  */
ComplexKeyCacheDictionary::FindResult ComplexKeyCacheDictionary::findCellIdx(const StringRef & key, const CellMetadata::time_point_t now, const size_t hash) const
{
    /// Reduce the hash to a cell index before adding probe offsets.
    /// Adding max_collision_length to the raw hash could wrap around for hashes close to the maximum of size_t,
    /// which would end the search before a single cell had been examined.
    const size_t start_idx = hash & size_overlap_mask;

    auto oldest_id = start_idx;
    auto oldest_time = CellMetadata::time_point_t::max();

    for (size_t offset = 0; offset < max_collision_length; ++offset)
    {
        const auto cell_idx = (start_idx + offset) & size_overlap_mask;
        const auto & cell = cells[cell_idx];

        if (cell.hash != hash || cell.key != key)
        {
            /// Another key lives here. Remember it as the replacement candidate
            /// unless an already expired candidate has been found earlier.
            if (oldest_time > now && oldest_time > cell.expiresAt())
            {
                oldest_time = cell.expiresAt();
                oldest_id = cell_idx;
            }

            continue;
        }

        /// The key is found: report whether its cell is still fresh.
        if (cell.expiresAt() < now)
            return {cell_idx, false, true};

        return {cell_idx, true, false};
    }

    return {oldest_id, false, false};
}
void ComplexKeyCacheDictionary::has(const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types, PaddedPODArray<UInt8> & out) const void ComplexKeyCacheDictionary::has(const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types, PaddedPODArray<UInt8> & out) const
{ {
dict_struct.validateKeyTypes(key_types); dict_struct.validateKeyTypes(key_types);
...@@ -181,11 +228,12 @@ void ComplexKeyCacheDictionary::has(const ConstColumnPlainPtrs & key_columns, co ...@@ -181,11 +228,12 @@ void ComplexKeyCacheDictionary::has(const ConstColumnPlainPtrs & key_columns, co
/// Mapping: <key> -> { all indices `i` of `key_columns` such that `key_columns[i]` = <key> } /// Mapping: <key> -> { all indices `i` of `key_columns` such that `key_columns[i]` = <key> }
MapType<std::vector<size_t>> outdated_keys; MapType<std::vector<size_t>> outdated_keys;
const auto rows = key_columns.front()->size();
const auto rows_num = key_columns.front()->size();
const auto keys_size = dict_struct.key.value().size(); const auto keys_size = dict_struct.key.value().size();
StringRefs keys(keys_size); StringRefs keys(keys_size);
Arena temporary_keys_pool; Arena temporary_keys_pool;
PODArray<StringRef> keys_array(rows); PODArray<StringRef> keys_array(rows_num);
size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0; size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0;
{ {
...@@ -193,31 +241,28 @@ void ComplexKeyCacheDictionary::has(const ConstColumnPlainPtrs & key_columns, co ...@@ -193,31 +241,28 @@ void ComplexKeyCacheDictionary::has(const ConstColumnPlainPtrs & key_columns, co
const auto now = std::chrono::system_clock::now(); const auto now = std::chrono::system_clock::now();
/// fetch up-to-date values, decide which ones require update /// fetch up-to-date values, decide which ones require update
for (const auto row : ext::range(0, rows)) for (const auto row : ext::range(0, rows_num))
{ {
const StringRef key = placeKeysInPool(row, key_columns, keys, temporary_keys_pool); const StringRef key = placeKeysInPool(row, key_columns, keys, temporary_keys_pool);
keys_array[row] = key; keys_array[row] = key;
const auto hash = StringRefHash{}(key); const auto find_result = findCellIdx(key, now);
const size_t cell_idx = hash & (size - 1); const auto & cell_idx = find_result.cell_idx;
const auto & cell = cells[cell_idx];
/** cell should be updated if either: /** cell should be updated if either:
* 1. keys (or hash) do not match, * 1. keys (or hash) do not match,
* 2. cell has expired, * 2. cell has expired,
* 3. explicit defaults were specified and cell was set default. */ * 3. explicit defaults were specified and cell was set default. */
if (cell.hash != hash || cell.key != key) if (!find_result.valid)
{
++cache_not_found;
outdated_keys[key].push_back(row);
}
else if (cell.expiresAt() < now)
{ {
++cache_expired;
outdated_keys[key].push_back(row); outdated_keys[key].push_back(row);
if (find_result.outdated)
++cache_expired;
else
++cache_not_found;
} }
else else
{ {
++cache_hit; ++cache_hit;
const auto & cell = cells[cell_idx];
out[row] = !cell.isDefault(); out[row] = !cell.isDefault();
} }
} }
...@@ -226,8 +271,8 @@ void ComplexKeyCacheDictionary::has(const ConstColumnPlainPtrs & key_columns, co ...@@ -226,8 +271,8 @@ void ComplexKeyCacheDictionary::has(const ConstColumnPlainPtrs & key_columns, co
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found); ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found);
ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit); ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit);
query_count.fetch_add(rows, std::memory_order_relaxed); query_count.fetch_add(rows_num, std::memory_order_relaxed);
hit_count.fetch_add(rows - outdated_keys.size(), std::memory_order_release); hit_count.fetch_add(rows_num - outdated_keys.size(), std::memory_order_release);
if (outdated_keys.empty()) if (outdated_keys.empty())
return; return;
...@@ -376,11 +421,11 @@ void ComplexKeyCacheDictionary::getItemsNumberImpl( ...@@ -376,11 +421,11 @@ void ComplexKeyCacheDictionary::getItemsNumberImpl(
MapType<std::vector<size_t>> outdated_keys; MapType<std::vector<size_t>> outdated_keys;
auto & attribute_array = std::get<ContainerPtrType<AttributeType>>(attribute.arrays); auto & attribute_array = std::get<ContainerPtrType<AttributeType>>(attribute.arrays);
const auto rows = key_columns.front()->size(); const auto rows_num = key_columns.front()->size();
const auto keys_size = dict_struct.key.value().size(); const auto keys_size = dict_struct.key.value().size();
StringRefs keys(keys_size); StringRefs keys(keys_size);
Arena temporary_keys_pool; Arena temporary_keys_pool;
PODArray<StringRef> keys_array(rows); PODArray<StringRef> keys_array(rows_num);
size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0; size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0;
{ {
...@@ -388,31 +433,30 @@ void ComplexKeyCacheDictionary::getItemsNumberImpl( ...@@ -388,31 +433,30 @@ void ComplexKeyCacheDictionary::getItemsNumberImpl(
const auto now = std::chrono::system_clock::now(); const auto now = std::chrono::system_clock::now();
/// fetch up-to-date values, decide which ones require update /// fetch up-to-date values, decide which ones require update
for (const auto row : ext::range(0, rows)) for (const auto row : ext::range(0, rows_num))
{ {
const StringRef key = placeKeysInPool(row, key_columns, keys, temporary_keys_pool); const StringRef key = placeKeysInPool(row, key_columns, keys, temporary_keys_pool);
keys_array[row] = key; keys_array[row] = key;
const auto hash = StringRefHash{}(key); const auto find_result = findCellIdx(key, now);
const size_t cell_idx = hash & (size - 1);
const auto & cell = cells[cell_idx];
/** cell should be updated if either: /** cell should be updated if either:
* 1. keys (or hash) do not match, * 1. keys (or hash) do not match,
* 2. cell has expired, * 2. cell has expired,
* 3. explicit defaults were specified and cell was set default. */ * 3. explicit defaults were specified and cell was set default. */
if (cell.hash != hash || cell.key != key)
{ if (!find_result.valid)
++cache_not_found;
outdated_keys[key].push_back(row);
}
else if (cell.expiresAt() < now)
{ {
++cache_expired;
outdated_keys[key].push_back(row); outdated_keys[key].push_back(row);
if (find_result.outdated)
++cache_expired;
else
++cache_not_found;
} }
else else
{ {
++cache_hit; ++cache_hit;
const auto & cell_idx = find_result.cell_idx;
const auto & cell = cells[cell_idx];
out[row] = cell.isDefault() ? get_default(row) : attribute_array[cell_idx]; out[row] = cell.isDefault() ? get_default(row) : attribute_array[cell_idx];
} }
} }
...@@ -420,9 +464,8 @@ void ComplexKeyCacheDictionary::getItemsNumberImpl( ...@@ -420,9 +464,8 @@ void ComplexKeyCacheDictionary::getItemsNumberImpl(
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired); ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired);
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found); ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found);
ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit); ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit);
query_count.fetch_add(rows_num, std::memory_order_relaxed);
query_count.fetch_add(rows, std::memory_order_relaxed); hit_count.fetch_add(rows_num - outdated_keys.size(), std::memory_order_release);
hit_count.fetch_add(rows - outdated_keys.size(), std::memory_order_release);
if (outdated_keys.empty()) if (outdated_keys.empty())
return; return;
...@@ -451,9 +494,9 @@ void ComplexKeyCacheDictionary::getItemsString( ...@@ -451,9 +494,9 @@ void ComplexKeyCacheDictionary::getItemsString(
Attribute & attribute, const ConstColumnPlainPtrs & key_columns, ColumnString * out, Attribute & attribute, const ConstColumnPlainPtrs & key_columns, ColumnString * out,
DefaultGetter && get_default) const DefaultGetter && get_default) const
{ {
const auto rows = key_columns.front()->size(); const auto rows_num = key_columns.front()->size();
/// save on some allocations /// save on some allocations
out->getOffsets().reserve(rows); out->getOffsets().reserve(rows_num);
const auto keys_size = dict_struct.key.value().size(); const auto keys_size = dict_struct.key.value().size();
StringRefs keys(keys_size); StringRefs keys(keys_size);
...@@ -469,21 +512,21 @@ void ComplexKeyCacheDictionary::getItemsString( ...@@ -469,21 +512,21 @@ void ComplexKeyCacheDictionary::getItemsString(
const auto now = std::chrono::system_clock::now(); const auto now = std::chrono::system_clock::now();
/// fetch up-to-date values, discard on fail /// fetch up-to-date values, discard on fail
for (const auto row : ext::range(0, rows)) for (const auto row : ext::range(0, rows_num))
{ {
const StringRef key = placeKeysInPool(row, key_columns, keys, temporary_keys_pool); const StringRef key = placeKeysInPool(row, key_columns, keys, temporary_keys_pool);
SCOPE_EXIT(temporary_keys_pool.rollback(key.size)); SCOPE_EXIT(temporary_keys_pool.rollback(key.size));
const auto hash = StringRefHash{}(key); const auto find_result = findCellIdx(key, now);
const size_t cell_idx = hash & (size - 1);
const auto & cell = cells[cell_idx];
if (cell.hash != hash || cell.key != key || cell.expiresAt() < now) if (!find_result.valid)
{ {
found_outdated_values = true; found_outdated_values = true;
break; break;
} }
else else
{ {
const auto & cell_idx = find_result.cell_idx;
const auto & cell = cells[cell_idx];
const auto string_ref = cell.isDefault() ? get_default(row) : attribute_array[cell_idx]; const auto string_ref = cell.isDefault() ? get_default(row) : attribute_array[cell_idx];
out->insertData(string_ref.data, string_ref.size); out->insertData(string_ref.data, string_ref.size);
} }
...@@ -493,8 +536,8 @@ void ComplexKeyCacheDictionary::getItemsString( ...@@ -493,8 +536,8 @@ void ComplexKeyCacheDictionary::getItemsString(
/// optimistic code completed successfully /// optimistic code completed successfully
if (!found_outdated_values) if (!found_outdated_values)
{ {
query_count.fetch_add(rows, std::memory_order_relaxed); query_count.fetch_add(rows_num, std::memory_order_relaxed);
hit_count.fetch_add(rows, std::memory_order_release); hit_count.fetch_add(rows_num, std::memory_order_release);
return; return;
} }
...@@ -506,7 +549,7 @@ void ComplexKeyCacheDictionary::getItemsString( ...@@ -506,7 +549,7 @@ void ComplexKeyCacheDictionary::getItemsString(
MapType<std::vector<size_t>> outdated_keys; MapType<std::vector<size_t>> outdated_keys;
/// we are going to store every string separately /// we are going to store every string separately
MapType<StringRef> map; MapType<StringRef> map;
PODArray<StringRef> keys_array(rows); PODArray<StringRef> keys_array(rows_num);
size_t total_length = 0; size_t total_length = 0;
size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0; size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0;
...@@ -514,27 +557,25 @@ void ComplexKeyCacheDictionary::getItemsString( ...@@ -514,27 +557,25 @@ void ComplexKeyCacheDictionary::getItemsString(
const ProfilingScopedReadRWLock read_lock{rw_lock, ProfileEvents::DictCacheLockReadNs}; const ProfilingScopedReadRWLock read_lock{rw_lock, ProfileEvents::DictCacheLockReadNs};
const auto now = std::chrono::system_clock::now(); const auto now = std::chrono::system_clock::now();
for (const auto row : ext::range(0, rows)) for (const auto row : ext::range(0, rows_num))
{ {
const StringRef key = placeKeysInPool(row, key_columns, keys, temporary_keys_pool); const StringRef key = placeKeysInPool(row, key_columns, keys, temporary_keys_pool);
keys_array[row] = key; keys_array[row] = key;
const auto hash = StringRefHash{}(key); const auto find_result = findCellIdx(key, now);
const size_t cell_idx = hash & (size - 1);
const auto & cell = cells[cell_idx];
if (cell.hash != hash || cell.key != key) if (!find_result.valid)
{
++cache_not_found;
outdated_keys[key].push_back(row);
}
else if (cell.expiresAt() < now)
{ {
++cache_expired;
outdated_keys[key].push_back(row); outdated_keys[key].push_back(row);
if (find_result.outdated)
++cache_expired;
else
++cache_not_found;
} }
else else
{ {
++cache_hit; ++cache_hit;
const auto & cell_idx = find_result.cell_idx;
const auto & cell = cells[cell_idx];
const auto string_ref = cell.isDefault() ? get_default(row) : attribute_array[cell_idx]; const auto string_ref = cell.isDefault() ? get_default(row) : attribute_array[cell_idx];
if (!cell.isDefault()) if (!cell.isDefault())
...@@ -548,8 +589,8 @@ void ComplexKeyCacheDictionary::getItemsString( ...@@ -548,8 +589,8 @@ void ComplexKeyCacheDictionary::getItemsString(
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found); ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found);
ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit); ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit);
query_count.fetch_add(rows, std::memory_order_relaxed); query_count.fetch_add(rows_num, std::memory_order_relaxed);
hit_count.fetch_add(rows - outdated_keys.size(), std::memory_order_release); hit_count.fetch_add(rows_num - outdated_keys.size(), std::memory_order_release);
/// request new values /// request new values
if (!outdated_keys.empty()) if (!outdated_keys.empty())
...@@ -614,6 +655,7 @@ void ComplexKeyCacheDictionary::update( ...@@ -614,6 +655,7 @@ void ComplexKeyCacheDictionary::update(
StringRefs keys(keys_size); StringRefs keys(keys_size);
const auto attributes_size = attributes.size(); const auto attributes_size = attributes.size();
const auto now = std::chrono::system_clock::now();
while (const auto block = stream->read()) while (const auto block = stream->read())
{ {
...@@ -632,13 +674,14 @@ void ComplexKeyCacheDictionary::update( ...@@ -632,13 +674,14 @@ void ComplexKeyCacheDictionary::update(
return block.safeGetByPosition(keys_size + attribute_idx).column.get(); return block.safeGetByPosition(keys_size + attribute_idx).column.get();
}); });
const auto rows = block.rows(); const auto rows_num = block.rows();
for (const auto row : ext::range(0, rows)) for (const auto row : ext::range(0, rows_num))
{ {
auto key = allocKey(row, key_columns, keys); auto key = allocKey(row, key_columns, keys);
const auto hash = StringRefHash{}(key); const auto hash = StringRefHash{}(key);
const size_t cell_idx = hash & (size - 1); const auto find_result = findCellIdx(key, now, hash);
const auto & cell_idx = find_result.cell_idx;
auto & cell = cells[cell_idx]; auto & cell = cells[cell_idx];
for (const auto attribute_idx : ext::range(0, attributes.size())) for (const auto attribute_idx : ext::range(0, attributes.size()))
...@@ -691,6 +734,8 @@ void ComplexKeyCacheDictionary::update( ...@@ -691,6 +734,8 @@ void ComplexKeyCacheDictionary::update(
size_t found_num = 0; size_t found_num = 0;
size_t not_found_num = 0; size_t not_found_num = 0;
const auto now = std::chrono::system_clock::now();
/// Check which ids have not been found and require setting null_value /// Check which ids have not been found and require setting null_value
for (const auto key_found_pair : remaining_keys) for (const auto key_found_pair : remaining_keys)
{ {
...@@ -704,7 +749,8 @@ void ComplexKeyCacheDictionary::update( ...@@ -704,7 +749,8 @@ void ComplexKeyCacheDictionary::update(
auto key = key_found_pair.first; auto key = key_found_pair.first;
const auto hash = StringRefHash{}(key); const auto hash = StringRefHash{}(key);
const size_t cell_idx = hash & (size - 1); const auto find_result = findCellIdx(key, now, hash);
const auto & cell_idx = find_result.cell_idx;
auto & cell = cells[cell_idx]; auto & cell = cells[cell_idx];
/// Set null_value for each attribute /// Set null_value for each attribute
......
...@@ -2723,10 +2723,6 @@ void FunctionArrayReduce::getReturnTypeAndPrerequisitesImpl( ...@@ -2723,10 +2723,6 @@ void FunctionArrayReduce::getReturnTypeAndPrerequisitesImpl(
aggregate_function = AggregateFunctionFactory().get(aggregate_function_name, argument_types); aggregate_function = AggregateFunctionFactory().get(aggregate_function_name, argument_types);
/// Потому что владение состояниями агрегатных функций никуда не отдаётся.
if (aggregate_function->isState())
throw Exception("Using aggregate function with -State modifier in function arrayReduce is not supported", ErrorCodes::BAD_ARGUMENTS);
if (has_parameters) if (has_parameters)
aggregate_function->setParameters(params_row); aggregate_function->setParameters(params_row);
aggregate_function->setArguments(argument_types); aggregate_function->setArguments(argument_types);
...@@ -2752,12 +2748,15 @@ void FunctionArrayReduce::executeImpl(Block & block, const ColumnNumbers & argum ...@@ -2752,12 +2748,15 @@ void FunctionArrayReduce::executeImpl(Block & block, const ColumnNumbers & argum
std::vector<const IColumn *> aggregate_arguments_vec(arguments.size() - 1); std::vector<const IColumn *> aggregate_arguments_vec(arguments.size() - 1);
bool is_const = true;
for (size_t i = 0, size = arguments.size() - 1; i < size; ++i) for (size_t i = 0, size = arguments.size() - 1; i < size; ++i)
{ {
const IColumn * col = block.getByPosition(arguments[i + 1]).column.get(); const IColumn * col = block.getByPosition(arguments[i + 1]).column.get();
if (const ColumnArray * arr = typeid_cast<const ColumnArray *>(col)) if (const ColumnArray * arr = typeid_cast<const ColumnArray *>(col))
{ {
aggregate_arguments_vec[i] = arr->getDataPtr().get(); aggregate_arguments_vec[i] = arr->getDataPtr().get();
is_const = false;
} }
else if (const ColumnConstArray * arr = typeid_cast<const ColumnConstArray *>(col)) else if (const ColumnConstArray * arr = typeid_cast<const ColumnConstArray *>(col))
{ {
...@@ -2774,9 +2773,12 @@ void FunctionArrayReduce::executeImpl(Block & block, const ColumnNumbers & argum ...@@ -2774,9 +2773,12 @@ void FunctionArrayReduce::executeImpl(Block & block, const ColumnNumbers & argum
? *materialized_columns.front().get() ? *materialized_columns.front().get()
: *block.getByPosition(arguments[1]).column.get()).getOffsets(); : *block.getByPosition(arguments[1]).column.get()).getOffsets();
ColumnPtr result_holder = block.safeGetByPosition(result).type->createColumn(); ColumnPtr result_holder = block.safeGetByPosition(result).type->createColumn();
block.safeGetByPosition(result).column = result_holder; IColumn & res_col = *result_holder;
IColumn & res_col = *result_holder.get();
/// AggregateFunction's states should be inserted into column using specific way
auto res_col_aggregate_function = typeid_cast<ColumnAggregateFunction *>(&res_col);
ColumnArray::Offset_t current_offset = 0; ColumnArray::Offset_t current_offset = 0;
for (size_t i = 0; i < rows; ++i) for (size_t i = 0; i < rows; ++i)
...@@ -2789,7 +2791,10 @@ void FunctionArrayReduce::executeImpl(Block & block, const ColumnNumbers & argum ...@@ -2789,7 +2791,10 @@ void FunctionArrayReduce::executeImpl(Block & block, const ColumnNumbers & argum
for (size_t j = current_offset; j < next_offset; ++j) for (size_t j = current_offset; j < next_offset; ++j)
agg_func.add(place, aggregate_arguments, j, arena.get()); agg_func.add(place, aggregate_arguments, j, arena.get());
agg_func.insertResultInto(place, res_col); if (!res_col_aggregate_function)
agg_func.insertResultInto(place, res_col);
else
res_col_aggregate_function->insertFrom(place);
} }
catch (...) catch (...)
{ {
...@@ -2800,6 +2805,15 @@ void FunctionArrayReduce::executeImpl(Block & block, const ColumnNumbers & argum ...@@ -2800,6 +2805,15 @@ void FunctionArrayReduce::executeImpl(Block & block, const ColumnNumbers & argum
agg_func.destroy(place); agg_func.destroy(place);
current_offset = next_offset; current_offset = next_offset;
} }
if (!is_const)
{
block.safeGetByPosition(result).column = result_holder;
}
else
{
block.safeGetByPosition(result).column = block.safeGetByPosition(result).type->createConstColumn(rows, res_col[0]);
}
} }
......
...@@ -168,9 +168,14 @@ bool Set::insertFromBlock(const Block & block, bool create_ordered_set) ...@@ -168,9 +168,14 @@ bool Set::insertFromBlock(const Block & block, bool create_ordered_set)
static Field extractValueFromNode(ASTPtr & node, const IDataType & type, const Context & context) static Field extractValueFromNode(ASTPtr & node, const IDataType & type, const Context & context)
{ {
if (ASTLiteral * lit = typeid_cast<ASTLiteral *>(node.get())) if (ASTLiteral * lit = typeid_cast<ASTLiteral *>(node.get()))
{
return convertFieldToType(lit->value, type); return convertFieldToType(lit->value, type);
}
else if (typeid_cast<ASTFunction *>(node.get())) else if (typeid_cast<ASTFunction *>(node.get()))
return convertFieldToType(evaluateConstantExpression(node, context), type); {
std::pair<Field, DataTypePtr> value_raw = evaluateConstantExpression(node, context);
return convertFieldToType(value_raw.first, type, value_raw.second.get());
}
else else
throw Exception("Incorrect element of set. Must be literal or constant expression.", ErrorCodes::INCORRECT_ELEMENT_OF_SET); throw Exception("Incorrect element of set. Must be literal or constant expression.", ErrorCodes::INCORRECT_ELEMENT_OF_SET);
} }
......
...@@ -149,16 +149,19 @@ Field convertFieldToTypeImpl(const Field & src, const IDataType & type) ...@@ -149,16 +149,19 @@ Field convertFieldToTypeImpl(const Field & src, const IDataType & type)
} }
Field convertFieldToType(const Field & src, const IDataType & type) Field convertFieldToType(const Field & from_value, const IDataType & to_type, const IDataType * from_type_hint)
{ {
if (type.isNullable()) if (from_type_hint && from_type_hint->equals(to_type))
return from_value;
if (to_type.isNullable())
{ {
const DataTypeNullable & nullable_type = static_cast<const DataTypeNullable &>(type); const DataTypeNullable & nullable_type = static_cast<const DataTypeNullable &>(to_type);
const DataTypePtr & nested_type = nullable_type.getNestedType(); const DataTypePtr & nested_type = nullable_type.getNestedType();
return convertFieldToTypeImpl(src, *nested_type); return convertFieldToTypeImpl(from_value, *nested_type);
} }
else else
return convertFieldToTypeImpl(src, type); return convertFieldToTypeImpl(from_value, to_type);
} }
......
...@@ -20,7 +20,7 @@ namespace ErrorCodes ...@@ -20,7 +20,7 @@ namespace ErrorCodes
} }
Field evaluateConstantExpression(ASTPtr & node, const Context & context) std::pair<Field, std::shared_ptr<IDataType>> evaluateConstantExpression(std::shared_ptr<IAST> & node, const Context & context)
{ {
ExpressionActionsPtr expr_for_constant_folding = ExpressionAnalyzer( ExpressionActionsPtr expr_for_constant_folding = ExpressionAnalyzer(
node, context, nullptr, NamesAndTypesList{{ "_dummy", std::make_shared<DataTypeUInt8>() }}).getConstActions(); node, context, nullptr, NamesAndTypesList{{ "_dummy", std::make_shared<DataTypeUInt8>() }}).getConstActions();
...@@ -38,12 +38,13 @@ Field evaluateConstantExpression(ASTPtr & node, const Context & context) ...@@ -38,12 +38,13 @@ Field evaluateConstantExpression(ASTPtr & node, const Context & context)
if (!block_with_constants.has(name)) if (!block_with_constants.has(name))
throw Exception("Element of set in IN or VALUES is not a constant expression: " + name, ErrorCodes::BAD_ARGUMENTS); throw Exception("Element of set in IN or VALUES is not a constant expression: " + name, ErrorCodes::BAD_ARGUMENTS);
const IColumn & result_column = *block_with_constants.getByName(name).column; const ColumnWithTypeAndName & result = block_with_constants.getByName(name);
const IColumn & result_column = *result.column;
if (!result_column.isConst()) if (!result_column.isConst())
throw Exception("Element of set in IN or VALUES is not a constant expression: " + name, ErrorCodes::BAD_ARGUMENTS); throw Exception("Element of set in IN or VALUES is not a constant expression: " + name, ErrorCodes::BAD_ARGUMENTS);
return result_column[0]; return std::make_pair(result_column[0], result.type);
} }
...@@ -53,7 +54,7 @@ ASTPtr evaluateConstantExpressionAsLiteral(ASTPtr & node, const Context & contex ...@@ -53,7 +54,7 @@ ASTPtr evaluateConstantExpressionAsLiteral(ASTPtr & node, const Context & contex
return node; return node;
return std::make_shared<ASTLiteral>(node->range, return std::make_shared<ASTLiteral>(node->range,
evaluateConstantExpression(node, context)); evaluateConstantExpression(node, context).first);
} }
......
...@@ -399,7 +399,7 @@ bool PKCondition::isPrimaryKeyPossiblyWrappedByMonotonicFunctionsImpl( ...@@ -399,7 +399,7 @@ bool PKCondition::isPrimaryKeyPossiblyWrappedByMonotonicFunctionsImpl(
static void castValueToType(const DataTypePtr & desired_type, Field & src_value, const DataTypePtr & src_type, const ASTPtr & node) static void castValueToType(const DataTypePtr & desired_type, Field & src_value, const DataTypePtr & src_type, const ASTPtr & node)
{ {
if (desired_type->getName() == src_type->getName()) if (desired_type->equals(*src_type))
return; return;
try try
......
DROP TABLE IF EXISTS test.group_uniq_array_int; DROP TABLE IF EXISTS test.group_uniq_arr_int;
CREATE TABLE test.group_uniq_arr_int ENGINE = Memory AS CREATE TABLE test.group_uniq_arr_int ENGINE = Memory AS
SELECT g as id, if(c == 0, [v], if(c == 1, emptyArrayInt64(), [v, v])) as v FROM SELECT g as id, if(c == 0, [v], if(c == 1, emptyArrayInt64(), [v, v])) as v FROM
(SELECT intDiv(number%1000000, 100) as v, intDiv(number%100, 10) as g, number%10 as c FROM system.numbers WHERE c < 3 LIMIT 10000000); (SELECT intDiv(number%1000000, 100) as v, intDiv(number%100, 10) as g, number%10 as c FROM system.numbers WHERE c < 3 LIMIT 10000000);
......
0 200
1 100
0 200 nan
1 100 nan
0 200 nan
1 100 nan
2 200 101
0 200 nan ['---']
1 100 nan ['---']
2 200 101 ['---']
0 200 nan ['---']
1 100 nan ['---']
2 200 101 ['---']
3 200 102 ['igua']
0 200 nan ['---']
1 100 nan ['---']
2 200 101 ['---']
3 200 102 ['igua']
---
---
1
0
-- Exercises AggregateFunction-typed columns populated through arrayReduce('<func>State', ...):
-- as a column DEFAULT, in INSERT ... SELECT, in INSERT ... VALUES, and as operands of IN.
DROP TABLE IF EXISTS test.agg_func_col;
-- `d` stores a sum(UInt64) state; its DEFAULT is the state built by arrayReduce('sumState', ...) over [200].
CREATE TABLE test.agg_func_col (p Date, k UInt8, d AggregateFunction(sum, UInt64) DEFAULT arrayReduce('sumState', [toUInt64(200)])) ENGINE = AggregatingMergeTree(p, k, 1);
-- Only `k` is listed, so `d` is filled from its DEFAULT expression.
INSERT INTO test.agg_func_col (k) VALUES (0);
-- Row for k = 1 carries an explicitly built state over [100].
INSERT INTO test.agg_func_col SELECT 1 AS k, arrayReduce('sumState', [toUInt64(100)]) AS d;
-- Merge the stored states per key (expected: 200 for k = 0, 100 for k = 1).
SELECT k, sumMerge(d) FROM test.agg_func_col GROUP BY k ORDER BY k;
-- Empty-string row emitted as a separator between result sets.
SELECT '';
-- New avg-state column declared without a DEFAULT; the expected output shows nan for rows inserted before it existed.
ALTER TABLE test.agg_func_col ADD COLUMN af_avg1 AggregateFunction(avg, UInt8);
SELECT k, sumMerge(d), avgMerge(af_avg1) FROM test.agg_func_col GROUP BY k ORDER BY k;
SELECT '';
-- arrayReduce(...) is used as a value expression inside VALUES rather than a literal.
INSERT INTO test.agg_func_col (k, af_avg1) VALUES (2, arrayReduce('avgState', [101]));
SELECT k, sumMerge(d), avgMerge(af_avg1) FROM test.agg_func_col GROUP BY k ORDER BY k;
SELECT '';
-- New groupUniqArray-state column whose DEFAULT is itself an arrayReduce(...State) expression;
-- the duplicated '---' is expected to appear once in the merged result.
ALTER TABLE test.agg_func_col ADD COLUMN af_gua AggregateFunction(groupUniqArray, String) DEFAULT arrayReduce('groupUniqArrayState', ['---', '---']);
SELECT k, sumMerge(d), avgMerge(af_avg1), groupUniqArrayMerge(af_gua) FROM test.agg_func_col GROUP BY k ORDER BY k;
SELECT '';
-- Row k = 3 supplies both added state columns explicitly.
INSERT INTO test.agg_func_col (k, af_avg1, af_gua) VALUES (3, arrayReduce('avgState', [102, 102]), arrayReduce('groupUniqArrayState', ['igua', 'igua']));
SELECT k, sumMerge(d), avgMerge(af_avg1), groupUniqArrayMerge(af_gua) FROM test.agg_func_col GROUP BY k ORDER BY k;
-- Run OPTIMIZE, then repeat the same query; the expected output lists the same four rows as before.
OPTIMIZE TABLE test.agg_func_col;
SELECT '';
SELECT k, sumMerge(d), avgMerge(af_avg1), groupUniqArrayMerge(af_gua) FROM test.agg_func_col GROUP BY k ORDER BY k;
-- Clean up the table.
DROP TABLE IF EXISTS test.agg_func_col;
SELECT '';
-- Combinators over Nullable(String) arrays: -If with an all-ones condition array, and -Merge applied to a
-- -State result; each query is expected to return '---' as the first array element.
SELECT arrayReduce('groupUniqArrayIf', [CAST('---' AS Nullable(String)), CAST('---' AS Nullable(String))], [1, 1])[1];
SELECT arrayReduce('groupUniqArrayMerge', [arrayReduce('groupUniqArrayState', [CAST('---' AS Nullable(String)), CAST('---' AS Nullable(String))])])[1];
SELECT '';
-- Aggregate states as IN operands: the avg state over [0] is present in the first set (expected 1)
-- and absent from the second (expected 0).
SELECT arrayReduce('avgState', [0]) IN (arrayReduce('avgState', [0, 1]), arrayReduce('avgState', [0]));
SELECT arrayReduce('avgState', [0]) IN (arrayReduce('avgState', [0, 1]), arrayReduce('avgState', [1]));
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册