提交 b635808e 编写于 作者: A Alexey Milovidov

Added a patch from Андрей Иванов-Чакчир: function dictGetKeys [#CLICKHOUSE-2842].

上级 f46f3f9e
......@@ -121,6 +121,11 @@ public:
void has(const PaddedPODArray<Key> & ids, PaddedPODArray<UInt8> & out) const override;
void getKeys(const std::string & attribute_name,
std::vector<std::vector<Key>> &out, const std::string & value_of_target_attr) const;
void splitString(const std::string & str, std::vector<std::string> & strvector) const;
private:
template <typename Value> using ContainerType = PaddedPODArray<Value>;
template <typename Value> using ContainerPtrType = std::unique_ptr<ContainerType<Value>>;
......
......@@ -119,6 +119,10 @@ public:
void has(const PaddedPODArray<Key> & ids, PaddedPODArray<UInt8> & out) const override;
void getKeys(const std::string & attribute_name, std::vector<std::vector<Key>> &out, const std::string & value_of_target_attr) const;
void splitString(const std::string & str, std::vector<std::string> & strvector) const;
private:
template <typename Value> using CollectionType = HashMap<UInt64, Value>;
template <typename Value> using CollectionPtrType = std::unique_ptr<CollectionType<Value>>;
......
......@@ -2410,5 +2410,123 @@ private:
const ExternalDictionaries & dictionaries;
};
class FunctionDictGetKeys final : public IFunction
{
public:
static constexpr auto name = "dictGetKeys";
static FunctionPtr create(const Context & context)
{
return std::make_shared<FunctionDictGetKeys> (context.getExternalDictionaries());
}
FunctionDictGetKeys (const ExternalDictionaries & dictionaries) : dictionaries(dictionaries) {}
String getName() const override { return name; }
private:
size_t getNumberOfArguments() const override { return 3; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (arguments.size() != 3)
throw Exception{
"Number of arguments for function " + getName() + " doesn't match: passed "
+ toString(arguments.size()) + ", should be 3.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
if (!typeid_cast<const DataTypeString *>(arguments[0].get()))
{
throw Exception{
"Illegal type " + arguments[0]->getName() + " of first argument of function " + getName()
+ ", expected a string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
if (!typeid_cast<const DataTypeString *>(arguments[1].get()))
{
throw Exception{
"Illegal type " + arguments[1]->getName() + " of second argument of function " + getName()
+ ", expected a string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
if (!typeid_cast<const DataTypeString *>(arguments[2].get()))
{
throw Exception{
"Illegal type " + arguments[2]->getName() + " of third argument of function " + getName()
+ ", expected a string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt64>());
}
void executeImpl(Block & block, const ColumnNumbers & arguments, const size_t result) override
{
const auto dict_name_col = typeid_cast<const ColumnConst<String> *>(block.getByPosition(arguments[0]).column.get());
if (!dict_name_col)
throw Exception{
"First argument of function " + getName() + " must be a constant string",
ErrorCodes::ILLEGAL_COLUMN};
auto dict = dictionaries.getDictionary(dict_name_col->getData());
const auto dict_ptr = dict.get();
if (!executeDispatch<FlatDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatch<HashedDictionary>(block, arguments, result, dict_ptr))
throw Exception{
"Unsupported dictionary type " + dict_ptr->getTypeName(),
ErrorCodes::UNKNOWN_TYPE};
}
template <typename DictionaryType>
bool executeDispatch(
Block & block, const ColumnNumbers & arguments, const size_t result, const IDictionaryBase * const dictionary)
{
auto dict = typeid_cast<const DictionaryType *>(dictionary);
if (!dict)
return false;
const auto get_keys_vector = [&] (PaddedPODArray<UInt64> & out, PaddedPODArray<UInt64> & offsets) {
auto size = out.size();
size = 1;
std::vector<std::vector<IDictionary::Key>> outerIds(size);
const auto attr_name_col = typeid_cast<const ColumnConst<String> *>(block.getByPosition(arguments[1]).column.get());
const auto & attr_name = attr_name_col->getData();
const auto value_col = typeid_cast<const ColumnConst<String> *>(block.getByPosition(arguments[2]).column.get());
const auto & searched_value = value_col->getData();
dict->getKeys(attr_name, outerIds, searched_value);
out.reserve(outerIds[0].size());
offsets.resize(size);
for (const auto i : ext::range(0, size))
{
const auto & ids = outerIds[i];
out.insert_assume_reserved(std::begin(ids), std::end(ids));
offsets[i] = out.size();
}
};
const auto backend = std::make_shared<ColumnUInt64>();
const auto array = std::make_shared<ColumnArray>(backend);
get_keys_vector(backend->getData(), array->getOffsets());
block.getByPosition(result).column = std::make_shared<ColumnConstArray>(
1,(*array)[0].get<Array>(),
std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt64>()));
return true;
}
const ExternalDictionaries & dictionaries;
};
};
......@@ -468,4 +468,47 @@ void FlatDictionary::has(const Attribute & attribute, const PaddedPODArray<Key>
query_count.fetch_add(ids_count, std::memory_order_relaxed);
}
void FlatDictionary::getKeys(const std::string & attribute_name,
std::vector<std::vector<IDictionary::Key>> & out, const std::string & value_of_target_attr) const
{
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
const auto & attr = *std::get<ContainerPtrType<StringRef>>(attribute.arrays);
std::vector<std::string> strvector;
splitString(value_of_target_attr, strvector);
std::vector<int> idarray;
for (auto it = attr.begin(); it != attr.end(); it++)
{
for(auto str_it = strvector.begin(); str_it != strvector.end(); ++str_it)
{
if(*it == *str_it)
{
idarray.push_back(std::distance(attr.begin(),it));
out[0].push_back(std::distance(attr.begin(),it));
}
}
}
}
void FlatDictionary::splitString(const std::string & str, std::vector<std::string> & strvector) const
{
std::string delimiter("|");
size_t prev = 0;
size_t next;
size_t delim_size = delimiter.length();
while ((next = str.find(delimiter, prev)) != std::string::npos)
{
strvector.push_back(str.substr(prev, next - prev));
prev = next + delim_size;
}
strvector.push_back(str.substr(prev));
}
}
......@@ -420,4 +420,49 @@ void HashedDictionary::has(const Attribute & attribute, const PaddedPODArray<Key
query_count.fetch_add(rows, std::memory_order_relaxed);
}
void HashedDictionary::getKeys(const std::string & attribute_name,
std::vector<std::vector<IDictionary::Key>> & out, const std::string & value_of_target_attr) const
{
const auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
const auto & attr = *std::get<CollectionPtrType<StringRef>>(attribute.maps);
std::vector<std::string> strvector;
splitString(value_of_target_attr, strvector);
std::vector<int> idarray;
for (auto it = attr.begin(); it != attr.end(); ++it)
{
for(auto str_it = strvector.begin(); str_it != strvector.end(); ++str_it)
{
if(it->second == *str_it)
{
idarray.push_back(it->first);
out[0].push_back(it->first);
}
}
}
}
void HashedDictionary::splitString(const std::string & str, std::vector<std::string> & strvector) const
{
std::string delimiter("|");
size_t prev = 0;
size_t next;
size_t delim_size = delimiter.length();
while ((next = str.find(delimiter, prev)) != std::string::npos)
{
strvector.push_back(str.substr(prev, next - prev));
prev = next + delim_size;
}
strvector.push_back(str.substr(prev));
}
}
......@@ -38,6 +38,7 @@ void registerFunctionsDictionaries(FunctionFactory & factory)
factory.registerFunction<FunctionDictGetString>();
factory.registerFunction<FunctionDictGetHierarchy>();
factory.registerFunction<FunctionDictIsIn>();
factory.registerFunction<FunctionDictGetKeys>();
factory.registerFunction<FunctionDictGetUInt8OrDefault>();
factory.registerFunction<FunctionDictGetUInt16OrDefault>();
factory.registerFunction<FunctionDictGetUInt32OrDefault>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册