diff --git a/dbms/include/DB/Functions/FunctionsStringSearch.h b/dbms/include/DB/Functions/FunctionsStringSearch.h index 1b9aea0630e66dfa79285001b197ed3609c29319..ad2e45f3cc9c8a9e481afa1b1e3c77d7aa0756c1 100644 --- a/dbms/include/DB/Functions/FunctionsStringSearch.h +++ b/dbms/include/DB/Functions/FunctionsStringSearch.h @@ -15,10 +15,12 @@ #include #include #include +#include #include #include #include +#include namespace DB @@ -189,11 +191,10 @@ struct PositionCaseInsensitiveImpl cachel = _mm_srli_si128(cachel, 1); cacheu = _mm_srli_si128(cacheu, 1); - cachel = _mm_insert_epi8(cachel, std::tolower(*needle_pos), n - 1); - cacheu = _mm_insert_epi8(cacheu, std::toupper(*needle_pos), n - 1); - if (needle_pos != needle_end) { + cachel = _mm_insert_epi8(cachel, std::tolower(*needle_pos), n - 1); + cacheu = _mm_insert_epi8(cacheu, std::toupper(*needle_pos), n - 1); cachemask |= 1 << i; ++needle_pos; } @@ -204,7 +205,7 @@ struct PositionCaseInsensitiveImpl return ((page_size - 1) & reinterpret_cast(ptr)) <= page_size - n; }; - const auto find_ci = [&] (const UInt8 * haystack, const UInt8 * const haystack_end) -> const UInt8 * { + const auto find_ci = [&] (const UInt8 * haystack, const UInt8 * const haystack_end) { if (needle_begin == needle_end) return haystack; @@ -260,7 +261,7 @@ struct PositionCaseInsensitiveImpl } if (haystack == haystack_end) - return haystack; + return haystack_end; if (*haystack == l || *haystack == u) { @@ -332,18 +333,248 @@ struct PositionCaseInsensitiveUTF8Impl const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle, PODArray & res) { - throw Exception{ - "Not yet implemented", - ErrorCodes::NOT_IMPLEMENTED + using UTF8SequenceBuffer = UInt8[6]; + + /// returns UTF-8 code point sequence length judging by it's first octet + const auto utf8_seq_length = [] (const UInt8 first_octet) { + if (first_octet < 0x80u) + return 1ul; + + const std::size_t bits = 8; + const auto first_zero = _bit_scan_reverse(static_cast(~first_octet)); + + return bits - 1 - first_zero; + }; + + static const Poco::UTF8Encoding utf8; + UTF8SequenceBuffer l_seq, u_seq; + + const auto first_u32 = utf8.convert(reinterpret_cast(needle.data())); + const auto first_l_u32 = Poco::Unicode::toLower(first_u32); + const auto first_u_u32 = Poco::Unicode::toUpper(first_u32); + + /// lower and uppercase variants of the first octet of the first character in `needle` + utf8.convert(first_l_u32, l_seq, sizeof(l_seq)); + const auto l = l_seq[0]; + utf8.convert(first_u_u32, u_seq, sizeof(u_seq)); + const auto u = u_seq[0]; + /// for detecting leftmost position of the first symbol + const auto patl = _mm_set1_epi8(l); + const auto patu = _mm_set1_epi8(u); + /// lower and uppercase vectors of first 16 octets of `needle` + auto cachel = _mm_setzero_si128(); + auto cacheu = _mm_setzero_si128(); + int cachemask = 0; + std::size_t cache_valid_len{}; + std::size_t cache_actual_len{}; + + const auto n = sizeof(cachel); + const auto needle_begin = needle.data(); + const auto needle_end = needle_begin + needle.size(); + auto needle_pos = needle_begin; + + const auto utf8_sync_forward = [] (const UInt8 * & s) { + const UInt8 continuation_octet_mask = 0b11000000u; + while ((*s & continuation_octet_mask) == continuation_octet_mask) + ++s; + }; + + for (std::size_t i = 0; i < n;) + { + if (needle_pos == needle_end) + { + cachel = _mm_srli_si128(cachel, 1); + cacheu = _mm_srli_si128(cacheu, 1); + ++i; + + continue; + } + + const auto src_len = utf8_seq_length(static_cast(*needle_pos)); + const auto c_u32 = utf8.convert(reinterpret_cast(needle_pos)); + + const auto c_l_u32 = Poco::Unicode::toLower(c_u32); + const auto c_u_u32 = Poco::Unicode::toUpper(c_u32); + + const auto dst_l_len = static_cast(utf8.convert(c_l_u32, l_seq, sizeof(l_seq))); + const auto dst_u_len = static_cast(utf8.convert(c_u_u32, u_seq, sizeof(u_seq))); + + /// @note Unicode standard states it is a rare but possible occasion + if (!(dst_l_len == dst_u_len && dst_u_len == src_len)) + throw Exception{ + "UTF8 sequences with different lowercase and uppercase lengths are not supported", + ErrorCodes::UNSUPPORTED_PARAMETER + }; + + cache_actual_len += src_len; + if (cache_actual_len < n) + cache_valid_len += src_len; + + for (std::size_t j = 0; j < src_len && i < n; ++j, ++i) + { + cachel = _mm_srli_si128(cachel, 1); + cacheu = _mm_srli_si128(cacheu, 1); + + if (needle_pos != needle_end) + { + cachel = _mm_insert_epi8(cachel, l_seq[j], n - 1); + cacheu = _mm_insert_epi8(cacheu, u_seq[j], n - 1); + + cachemask |= 1 << i; + ++needle_pos; + } + } + } + + const auto page_size = getpagesize(); + const auto page_safe = [&] (const void * const ptr) { + return ((page_size - 1) & reinterpret_cast(ptr)) <= page_size - n; + }; + + const auto find_ci = [&] (const UInt8 * haystack, const UInt8 * const haystack_end) { + if (needle_begin == needle_end) + return haystack; + + while (haystack < haystack_end) + { + if (haystack + n <= haystack_end && page_safe(haystack)) + { + const auto v_haystack = _mm_loadu_si128(reinterpret_cast(haystack)); + const auto v_against_l = _mm_cmpeq_epi8(v_haystack, patl); + const auto v_against_u = _mm_cmpeq_epi8(v_haystack, patu); + const auto v_against_l_or_u = _mm_or_si128(v_against_l, v_against_u); + + const auto mask = _mm_movemask_epi8(v_against_l_or_u); + + if (mask == 0) + { + haystack += n; + utf8_sync_forward(haystack); + continue; + } + + const auto offset = _bit_scan_forward(mask); + haystack += offset; + + if (haystack < haystack_end && haystack + n <= haystack_end && page_safe(haystack)) + { + const auto v_haystack = _mm_loadu_si128(reinterpret_cast(haystack)); + const auto v_against_l = _mm_cmpeq_epi8(v_haystack, cachel); + const auto v_against_u = _mm_cmpeq_epi8(v_haystack, cacheu); + const auto v_against_l_or_u = _mm_or_si128(v_against_l, v_against_u); + const auto mask = _mm_movemask_epi8(v_against_l_or_u); + + if (0xffff == cachemask) + { + if (mask == cachemask) + { + auto s1 = haystack + cache_valid_len; + auto s2 = needle_begin + cache_valid_len; + + while (s1 < haystack_end && s2 < needle_end && + Poco::Unicode::toLower(utf8.convert(s1)) == + Poco::Unicode::toLower(utf8.convert(reinterpret_cast(s2)))) + { + /// @note assuming sequences for lowercase and uppercase have exact same length + const auto len = utf8_seq_length(*s1); + s1 += len, s2 += len; + } + + if (s2 == needle_end) + return haystack; + } + } + else if ((mask & cachemask) == cachemask) + return haystack; + + /// first octet was ok, but not the first 16, move to start of next sequence and reapply + haystack += utf8_seq_length(*haystack); + continue; + } + } + + if (haystack == haystack_end) + return haystack_end; + + if (*haystack == l || *haystack == u) + { + auto s1 = haystack; + auto s2 = needle_begin; + + while (s1 < haystack_end && s2 < needle_end && + Poco::Unicode::toLower(utf8.convert(s1)) == + Poco::Unicode::toLower(utf8.convert(reinterpret_cast(s2)))) + { + const auto len = utf8_seq_length(*s1); + s1 += len, s2 += len; + } + + if (s2 == needle_end) + return haystack; + } + + /// advance to the start of the next sequence + haystack += utf8_seq_length(*haystack); + } + + return haystack_end; }; + + const UInt8 * begin = &data[0]; + const UInt8 * pos = begin; + const UInt8 * end = pos + data.size(); + + /// Текущий индекс в массиве строк. + size_t i = 0; + + /// Искать будем следующее вхождение сразу во всех строках. + while (pos < end && end != (pos = find_ci(pos, end))) + { + /// Определим, к какому индексу оно относится. + while (begin + offsets[i] < pos) + { + res[i] = 0; + ++i; + } + + /// Проверяем, что вхождение не переходит через границы строк. + if (pos + needle.size() < begin + offsets[i]) + res[i] = (i != 0) ? pos - begin - offsets[i - 1] + 1 : (pos - begin + 1); + else + res[i] = 0; + + pos = begin + offsets[i]; + ++i; + } + + memset(&res[i], 0, (res.size() - i) * sizeof(res[0])); } - static void constant(const std::string & data, const std::string & needle, UInt64 & res) + static void constant(std::string data, std::string needle, UInt64 & res) { - throw Exception{ - "Not yet implemented", - ErrorCodes::NOT_IMPLEMENTED - }; + static const Poco::UTF8Encoding utf8; + + auto data_pos = reinterpret_cast(&data[0]); + const auto data_end = data_pos + data.size(); + while (data_pos < data_end) + { + const auto len = utf8.convert(Poco::Unicode::toLower(utf8.convert(data_pos)), data_pos, data_end - data_pos); + data_pos += len; + } + + auto needle_pos = reinterpret_cast(&needle[0]); + const auto needle_end = needle_pos + needle.size(); + while (needle_pos < needle_end) + { + const auto len = utf8.convert(Poco::Unicode::toLower(utf8.convert(needle_pos)), needle_pos, needle_end - needle_pos); + needle_pos += len; + } + + res = data.find(needle); + if (res == std::string::npos) + res = 0; + else + ++res; } }; @@ -1394,7 +1625,7 @@ public: struct NamePosition { static constexpr auto name = "position"; }; struct NamePositionUTF8 { static constexpr auto name = "positionUTF8"; }; struct NamePositionCaseInsensitive { static constexpr auto name = "positionCaseInsensitive"; }; -struct NamePositionCaseInsenseitiveUTF8 { static constexpr auto name = "positionCaseInsensitiveUTF8"; }; +struct NamePositionCaseInsensitiveUTF8 { static constexpr auto name = "positionCaseInsensitiveUTF8"; }; struct NameMatch { static constexpr auto name = "match"; }; struct NameLike { static constexpr auto name = "like"; }; struct NameNotLike { static constexpr auto name = "notLike"; }; @@ -1407,7 +1638,7 @@ struct NameReplaceRegexpAll { static constexpr auto name = "replaceRegexpAll" typedef FunctionsStringSearch FunctionPosition; typedef FunctionsStringSearch FunctionPositionUTF8; typedef FunctionsStringSearch FunctionPositionCaseInsensitive; -typedef FunctionsStringSearch FunctionPositionCaseInsensitiveUTF8; +typedef FunctionsStringSearch FunctionPositionCaseInsensitiveUTF8; typedef FunctionsStringSearch, NameMatch> FunctionMatch; typedef FunctionsStringSearch, NameLike> FunctionLike; typedef FunctionsStringSearch, NameNotLike> FunctionNotLike; diff --git a/dbms/src/Functions/FunctionsStringSearch.cpp b/dbms/src/Functions/FunctionsStringSearch.cpp index e51bf6423102264032129373a8c053a3a6a0ea4f..b41db74b9fda0e322d4a2e050414d3afd83cc69a 100644 --- a/dbms/src/Functions/FunctionsStringSearch.cpp +++ b/dbms/src/Functions/FunctionsStringSearch.cpp @@ -13,8 +13,7 @@ void registerFunctionsStringSearch(FunctionFactory & factory) factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); -/// @todo implement -// factory.registerFunction(); + factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction();