提交 2efae5a5 编写于 作者: A Andrey Mironov

dbms: positionCaseInsensitiveUTF8 and some fixes (out of bounds access etc.)...

dbms: positionCaseInsensitiveUTF8 and some fixes (out of bounds access etc.) for positionCaseInsensitive. [#METR-16752]
上级 f04da146
......@@ -15,10 +15,12 @@
#include <DB/Functions/IFunction.h>
#include <re2/re2.h>
#include <re2/stringpiece.h>
#include <Poco/UTF8Encoding.h>
#include <mutex>
#include <stack>
#include <statdaemons/ext/range.hpp>
#include <Poco/Unicode.h>
namespace DB
......@@ -189,11 +191,10 @@ struct PositionCaseInsensitiveImpl
cachel = _mm_srli_si128(cachel, 1);
cacheu = _mm_srli_si128(cacheu, 1);
cachel = _mm_insert_epi8(cachel, std::tolower(*needle_pos), n - 1);
cacheu = _mm_insert_epi8(cacheu, std::toupper(*needle_pos), n - 1);
if (needle_pos != needle_end)
{
cachel = _mm_insert_epi8(cachel, std::tolower(*needle_pos), n - 1);
cacheu = _mm_insert_epi8(cacheu, std::toupper(*needle_pos), n - 1);
cachemask |= 1 << i;
++needle_pos;
}
......@@ -204,7 +205,7 @@ struct PositionCaseInsensitiveImpl
return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n;
};
const auto find_ci = [&] (const UInt8 * haystack, const UInt8 * const haystack_end) -> const UInt8 * {
const auto find_ci = [&] (const UInt8 * haystack, const UInt8 * const haystack_end) {
if (needle_begin == needle_end)
return haystack;
......@@ -260,7 +261,7 @@ struct PositionCaseInsensitiveImpl
}
if (haystack == haystack_end)
return haystack;
return haystack_end;
if (*haystack == l || *haystack == u)
{
......@@ -332,18 +333,248 @@ struct PositionCaseInsensitiveUTF8Impl
const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle,
PODArray<UInt64> & res)
{
throw Exception{
"Not yet implemented",
ErrorCodes::NOT_IMPLEMENTED
using UTF8SequenceBuffer = UInt8[6];
/// returns UTF-8 code point sequence length judging by it's first octet
const auto utf8_seq_length = [] (const UInt8 first_octet) {
if (first_octet < 0x80u)
return 1ul;
const std::size_t bits = 8;
const auto first_zero = _bit_scan_reverse(static_cast<UInt8>(~first_octet));
return bits - 1 - first_zero;
};
static const Poco::UTF8Encoding utf8;
UTF8SequenceBuffer l_seq, u_seq;
const auto first_u32 = utf8.convert(reinterpret_cast<const UInt8 *>(needle.data()));
const auto first_l_u32 = Poco::Unicode::toLower(first_u32);
const auto first_u_u32 = Poco::Unicode::toUpper(first_u32);
/// lower and uppercase variants of the first octet of the first character in `needle`
utf8.convert(first_l_u32, l_seq, sizeof(l_seq));
const auto l = l_seq[0];
utf8.convert(first_u_u32, u_seq, sizeof(u_seq));
const auto u = u_seq[0];
/// for detecting leftmost position of the first symbol
const auto patl = _mm_set1_epi8(l);
const auto patu = _mm_set1_epi8(u);
/// lower and uppercase vectors of first 16 octets of `needle`
auto cachel = _mm_setzero_si128();
auto cacheu = _mm_setzero_si128();
int cachemask = 0;
std::size_t cache_valid_len{};
std::size_t cache_actual_len{};
const auto n = sizeof(cachel);
const auto needle_begin = needle.data();
const auto needle_end = needle_begin + needle.size();
auto needle_pos = needle_begin;
const auto utf8_sync_forward = [] (const UInt8 * & s) {
const UInt8 continuation_octet_mask = 0b11000000u;
while ((*s & continuation_octet_mask) == continuation_octet_mask)
++s;
};
for (std::size_t i = 0; i < n;)
{
if (needle_pos == needle_end)
{
cachel = _mm_srli_si128(cachel, 1);
cacheu = _mm_srli_si128(cacheu, 1);
++i;
continue;
}
const auto src_len = utf8_seq_length(static_cast<UInt8>(*needle_pos));
const auto c_u32 = utf8.convert(reinterpret_cast<const UInt8 *>(needle_pos));
const auto c_l_u32 = Poco::Unicode::toLower(c_u32);
const auto c_u_u32 = Poco::Unicode::toUpper(c_u32);
const auto dst_l_len = static_cast<UInt8>(utf8.convert(c_l_u32, l_seq, sizeof(l_seq)));
const auto dst_u_len = static_cast<UInt8>(utf8.convert(c_u_u32, u_seq, sizeof(u_seq)));
/// @note Unicode standard states it is a rare but possible occasion
if (!(dst_l_len == dst_u_len && dst_u_len == src_len))
throw Exception{
"UTF8 sequences with different lowercase and uppercase lengths are not supported",
ErrorCodes::UNSUPPORTED_PARAMETER
};
cache_actual_len += src_len;
if (cache_actual_len < n)
cache_valid_len += src_len;
for (std::size_t j = 0; j < src_len && i < n; ++j, ++i)
{
cachel = _mm_srli_si128(cachel, 1);
cacheu = _mm_srli_si128(cacheu, 1);
if (needle_pos != needle_end)
{
cachel = _mm_insert_epi8(cachel, l_seq[j], n - 1);
cacheu = _mm_insert_epi8(cacheu, u_seq[j], n - 1);
cachemask |= 1 << i;
++needle_pos;
}
}
}
const auto page_size = getpagesize();
const auto page_safe = [&] (const void * const ptr) {
return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n;
};
const auto find_ci = [&] (const UInt8 * haystack, const UInt8 * const haystack_end) {
if (needle_begin == needle_end)
return haystack;
while (haystack < haystack_end)
{
if (haystack + n <= haystack_end && page_safe(haystack))
{
const auto v_haystack = _mm_loadu_si128(reinterpret_cast<const __m128i *>(haystack));
const auto v_against_l = _mm_cmpeq_epi8(v_haystack, patl);
const auto v_against_u = _mm_cmpeq_epi8(v_haystack, patu);
const auto v_against_l_or_u = _mm_or_si128(v_against_l, v_against_u);
const auto mask = _mm_movemask_epi8(v_against_l_or_u);
if (mask == 0)
{
haystack += n;
utf8_sync_forward(haystack);
continue;
}
const auto offset = _bit_scan_forward(mask);
haystack += offset;
if (haystack < haystack_end && haystack + n <= haystack_end && page_safe(haystack))
{
const auto v_haystack = _mm_loadu_si128(reinterpret_cast<const __m128i *>(haystack));
const auto v_against_l = _mm_cmpeq_epi8(v_haystack, cachel);
const auto v_against_u = _mm_cmpeq_epi8(v_haystack, cacheu);
const auto v_against_l_or_u = _mm_or_si128(v_against_l, v_against_u);
const auto mask = _mm_movemask_epi8(v_against_l_or_u);
if (0xffff == cachemask)
{
if (mask == cachemask)
{
auto s1 = haystack + cache_valid_len;
auto s2 = needle_begin + cache_valid_len;
while (s1 < haystack_end && s2 < needle_end &&
Poco::Unicode::toLower(utf8.convert(s1)) ==
Poco::Unicode::toLower(utf8.convert(reinterpret_cast<const UInt8 *>(s2))))
{
/// @note assuming sequences for lowercase and uppercase have exact same length
const auto len = utf8_seq_length(*s1);
s1 += len, s2 += len;
}
if (s2 == needle_end)
return haystack;
}
}
else if ((mask & cachemask) == cachemask)
return haystack;
/// first octet was ok, but not the first 16, move to start of next sequence and reapply
haystack += utf8_seq_length(*haystack);
continue;
}
}
if (haystack == haystack_end)
return haystack_end;
if (*haystack == l || *haystack == u)
{
auto s1 = haystack;
auto s2 = needle_begin;
while (s1 < haystack_end && s2 < needle_end &&
Poco::Unicode::toLower(utf8.convert(s1)) ==
Poco::Unicode::toLower(utf8.convert(reinterpret_cast<const UInt8 *>(s2))))
{
const auto len = utf8_seq_length(*s1);
s1 += len, s2 += len;
}
if (s2 == needle_end)
return haystack;
}
/// advance to the start of the next sequence
haystack += utf8_seq_length(*haystack);
}
return haystack_end;
};
const UInt8 * begin = &data[0];
const UInt8 * pos = begin;
const UInt8 * end = pos + data.size();
/// Текущий индекс в массиве строк.
size_t i = 0;
/// Искать будем следующее вхождение сразу во всех строках.
while (pos < end && end != (pos = find_ci(pos, end)))
{
/// Определим, к какому индексу оно относится.
while (begin + offsets[i] < pos)
{
res[i] = 0;
++i;
}
/// Проверяем, что вхождение не переходит через границы строк.
if (pos + needle.size() < begin + offsets[i])
res[i] = (i != 0) ? pos - begin - offsets[i - 1] + 1 : (pos - begin + 1);
else
res[i] = 0;
pos = begin + offsets[i];
++i;
}
memset(&res[i], 0, (res.size() - i) * sizeof(res[0]));
}
static void constant(const std::string & data, const std::string & needle, UInt64 & res)
static void constant(std::string data, std::string needle, UInt64 & res)
{
throw Exception{
"Not yet implemented",
ErrorCodes::NOT_IMPLEMENTED
};
static const Poco::UTF8Encoding utf8;
auto data_pos = reinterpret_cast<UInt8 *>(&data[0]);
const auto data_end = data_pos + data.size();
while (data_pos < data_end)
{
const auto len = utf8.convert(Poco::Unicode::toLower(utf8.convert(data_pos)), data_pos, data_end - data_pos);
data_pos += len;
}
auto needle_pos = reinterpret_cast<UInt8 *>(&needle[0]);
const auto needle_end = needle_pos + needle.size();
while (needle_pos < needle_end)
{
const auto len = utf8.convert(Poco::Unicode::toLower(utf8.convert(needle_pos)), needle_pos, needle_end - needle_pos);
needle_pos += len;
}
res = data.find(needle);
if (res == std::string::npos)
res = 0;
else
++res;
}
};
......@@ -1394,7 +1625,7 @@ public:
struct NamePosition { static constexpr auto name = "position"; };
struct NamePositionUTF8 { static constexpr auto name = "positionUTF8"; };
struct NamePositionCaseInsensitive { static constexpr auto name = "positionCaseInsensitive"; };
struct NamePositionCaseInsenseitiveUTF8 { static constexpr auto name = "positionCaseInsensitiveUTF8"; };
struct NamePositionCaseInsensitiveUTF8 { static constexpr auto name = "positionCaseInsensitiveUTF8"; };
struct NameMatch { static constexpr auto name = "match"; };
struct NameLike { static constexpr auto name = "like"; };
struct NameNotLike { static constexpr auto name = "notLike"; };
......@@ -1407,7 +1638,7 @@ struct NameReplaceRegexpAll { static constexpr auto name = "replaceRegexpAll"
typedef FunctionsStringSearch<PositionImpl, NamePosition> FunctionPosition;
typedef FunctionsStringSearch<PositionUTF8Impl, NamePositionUTF8> FunctionPositionUTF8;
typedef FunctionsStringSearch<PositionCaseInsensitiveImpl, NamePositionCaseInsensitive> FunctionPositionCaseInsensitive;
typedef FunctionsStringSearch<PositionCaseInsensitiveUTF8Impl, NamePositionCaseInsenseitiveUTF8> FunctionPositionCaseInsensitiveUTF8;
typedef FunctionsStringSearch<PositionCaseInsensitiveUTF8Impl, NamePositionCaseInsensitiveUTF8> FunctionPositionCaseInsensitiveUTF8;
typedef FunctionsStringSearch<MatchImpl<false>, NameMatch> FunctionMatch;
typedef FunctionsStringSearch<MatchImpl<true>, NameLike> FunctionLike;
typedef FunctionsStringSearch<MatchImpl<true, true>, NameNotLike> FunctionNotLike;
......
......@@ -13,8 +13,7 @@ void registerFunctionsStringSearch(FunctionFactory & factory)
factory.registerFunction<FunctionPosition>();
factory.registerFunction<FunctionPositionUTF8>();
factory.registerFunction<FunctionPositionCaseInsensitive>();
/// @todo implement
// factory.registerFunction<FunctionPositionCaseInsensitiveUTF8>();
factory.registerFunction<FunctionPositionCaseInsensitiveUTF8>();
factory.registerFunction<FunctionMatch>();
factory.registerFunction<FunctionLike>();
factory.registerFunction<FunctionNotLike>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册