/// AggregateFunctionSequenceMatch.h: sequenceMatch and sequenceCount aggregate functions.
#pragma once

#include <AggregateFunctions/IAggregateFunction.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h>
#include <ext/range.h>
#include <Common/PODArray.h>
#include <Common/typeid_cast.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>

#include <algorithm>
#include <bitset>
#include <cstdint>
#include <cstring>
#include <functional>
#include <iterator>
#include <limits>
#include <stack>
#include <tuple>
#include <utility>


namespace DB
{

/// Error codes thrown by the sequenceMatch / sequenceCount aggregate functions in this header.
/// Every code referenced below is declared here; repeating an `extern` declaration that another
/// header also provides is harmless.
namespace ErrorCodes
{
    extern const int TOO_SLOW;
    extern const int TOO_LESS_ARGUMENTS_FOR_FUNCTION;
    extern const int TOO_MUCH_ARGUMENTS_FOR_FUNCTION;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int SYNTAX_ERROR;
    extern const int BAD_ARGUMENTS;
    extern const int LOGICAL_ERROR;
}

/// Functor that orders two `std::pair`s by their `.first` members alone, using `Comparator`
/// (for example `std::less`); the `.second` members never take part in the comparison.
template <template <typename> class Comparator>
struct ComparePairFirst final
{
    template <typename T1, typename T2>
    bool operator()(const std::pair<T1, T2> & left, const std::pair<T1, T2> & right) const
    {
        const Comparator<T1> compare_first{};
        return compare_first(left.first, right.first);
    }
};

struct AggregateFunctionSequenceMatchData final
{
42 43 44 45 46 47 48 49 50
    static constexpr auto max_events = 32;

    using Timestamp = std::uint32_t;
    using Events = std::bitset<max_events>;
    using TimestampEvents = std::pair<Timestamp, Events>;
    using Comparator = ComparePairFirst<std::less>;

    bool sorted = true;
    static constexpr size_t bytes_in_arena = 64;
A
Alexey Milovidov 已提交
51
    PODArray<TimestampEvents, bytes_in_arena, AllocatorWithStackMemory<Allocator<false>, bytes_in_arena>> events_list;
52 53 54 55 56 57

    void add(const Timestamp timestamp, const Events & events)
    {
        /// store information exclusively for rows with at least one event
        if (events.any())
        {
A
Alexey Milovidov 已提交
58
            events_list.emplace_back(timestamp, events);
59 60 61 62 63 64
            sorted = false;
        }
    }

    void merge(const AggregateFunctionSequenceMatchData & other)
    {
A
Alexey Milovidov 已提交
65
        const auto size = events_list.size();
66

A
Alexey Milovidov 已提交
67
        events_list.insert(std::begin(other.events_list), std::end(other.events_list));
68 69 70

        /// either sort whole container or do so partially merging ranges afterwards
        if (!sorted && !other.sorted)
A
Alexey Milovidov 已提交
71
            std::sort(std::begin(events_list), std::end(events_list), Comparator{});
72 73
        else
        {
A
Alexey Milovidov 已提交
74
            const auto begin = std::begin(events_list);
75
            const auto middle = std::next(begin, size);
A
Alexey Milovidov 已提交
76
            const auto end = std::end(events_list);
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93

            if (!sorted)
                std::sort(begin, middle, Comparator{});

            if (!other.sorted)
                std::sort(middle, end, Comparator{});

            std::inplace_merge(begin, middle, end, Comparator{});
        }

        sorted = true;
    }

    void sort()
    {
        if (!sorted)
        {
A
Alexey Milovidov 已提交
94
            std::sort(std::begin(events_list), std::end(events_list), Comparator{});
95 96 97 98 99 100 101
            sorted = true;
        }
    }

    void serialize(WriteBuffer & buf) const
    {
        writeBinary(sorted, buf);
A
Alexey Milovidov 已提交
102
        writeBinary(events_list.size(), buf);
103

A
Alexey Milovidov 已提交
104
        for (const auto & events : events_list)
105 106 107 108 109 110 111 112 113 114
        {
            writeBinary(events.first, buf);
            writeBinary(events.second.to_ulong(), buf);
        }
    }

    void deserialize(ReadBuffer & buf)
    {
        readBinary(sorted, buf);

115
        size_t size;
116 117
        readBinary(size, buf);

A
Alexey Milovidov 已提交
118 119
        events_list.clear();
        events_list.reserve(size);
120

121
        for (size_t i = 0; i < size; ++i)
122 123 124 125 126 127 128
        {
            std::uint32_t timestamp;
            readBinary(timestamp, buf);

            UInt64 events;
            readBinary(events, buf);

A
Alexey Milovidov 已提交
129
            events_list.emplace_back(timestamp, Events{events});
130 131
        }
    }
132 133
};

/// Upper bound on the steps a single match attempt may take; the matcher throws
/// once its per-attempt step counter goes beyond this value.
constexpr int sequence_match_max_iterations = 1000000;

/** sequenceMatch(pattern)(time, cond1, cond2, ...) returns 1 if the events of a group, ordered by
  * time, match `pattern`, and 0 otherwise. `time` must be DateTime; every condition must be UInt8.
  * Pattern syntax (see parsePattern):
  *     (?N)       - event condition N (numbered from 1) fired;
  *     .          - any single event;
  *     .*         - any number of events;
  *     (?t<=X), (?t<X), (?t>=X), (?t>X)
  *                - time condition relative to the previously matched event.
  */
class AggregateFunctionSequenceMatch : public IAggregateFunctionHelper<AggregateFunctionSequenceMatchData>
{
public:
    /// Time column plus at least two event conditions.
    static bool sufficientArgs(const size_t arg_count) { return arg_count >= 3; }

    String getName() const override { return "sequenceMatch"; }

    DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt8>(); }

    void setParameters(const Array & params) override
    {
        if (params.size() != 1)
            throw Exception{
                "Aggregate function " + getName() + " requires exactly one parameter.",
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};

        pattern = params.front().safeGet<std::string>();
    }

    /// Validates argument types, then compiles the pattern (which needs arg_count).
    void setArguments(const DataTypes & arguments) override
    {
        arg_count = arguments.size();

        if (!sufficientArgs(arg_count))
            throw Exception{
                "Aggregate function " + getName() + " requires at least 3 arguments.",
                ErrorCodes::TOO_LESS_ARGUMENTS_FOR_FUNCTION};

        if (arg_count - 1 > Data::max_events)
            throw Exception{
                "Aggregate function " + getName() + " supports up to " +
                    std::to_string(Data::max_events) + " event arguments.",
                ErrorCodes::TOO_MUCH_ARGUMENTS_FOR_FUNCTION};

        const auto time_arg = arguments.front().get();
        if (!typeid_cast<const DataTypeDateTime *>(time_arg))
            throw Exception{
                "Illegal type " + time_arg->getName() + " of first argument of aggregate function " +
                    getName() + ", must be DateTime",
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};

        for (const auto i : ext::range(1, arg_count))
        {
            const auto cond_arg = arguments[i].get();
            if (!typeid_cast<const DataTypeUInt8 *>(cond_arg))
                throw Exception{
                    "Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1) +
                        " of aggregate function " + getName() + ", must be UInt8",
                    ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
        }

        parsePattern();
    }

    /// Packs condition columns 1..arg_count-1 into a bitmask (bit i-1 for column i).
    void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
    {
        const auto timestamp = static_cast<const ColumnUInt32 *>(columns[0])->getData()[row_num];

        Data::Events events;
        for (const auto i : ext::range(1, arg_count))
        {
            const auto event = static_cast<const ColumnUInt8 *>(columns[i])->getData()[row_num];
            events.set(i - 1, event);
        }

        data(place).add(timestamp, events);
    }

    void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
    {
        data(place).merge(data(rhs));
    }

    void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
    {
        data(place).serialize(buf);
    }

    void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
    {
        data(place).deserialize(buf);
    }

    /// Sorts the collected events in place (via const_cast on the const state) and writes 1/0.
    void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
    {
        const_cast<Data &>(data(place)).sort();

        const auto & data_ref = data(place);

        const auto events_begin = std::begin(data_ref.events_list);
        const auto events_end = std::end(data_ref.events_list);
        auto events_it = events_begin;

        static_cast<ColumnUInt8 &>(to).getData().push_back(match(events_it, events_end));
    }

    static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena)
    {
        static_cast<const AggregateFunctionSequenceMatch &>(*that).add(place, columns, row_num, arena);
    }

    IAggregateFunction::AddFunc getAddressOfAddFunction() const override final { return &addFree; }

private:
    enum class PatternActionType
    {
        SpecificEvent,
        AnyEvent,
        KleeneStar,
        TimeLessOrEqual,
        TimeLess,
        TimeGreaterOrEqual,
        TimeGreater
    };

    /// One compiled pattern step; `extra` is the 0-based event index or the time-condition operand.
    struct PatternAction final
    {
        PatternActionType type;
        std::uint32_t extra;

        PatternAction() = default;
        PatternAction(const PatternActionType type, const std::uint32_t extra = 0) : type{type}, extra{extra} {}
    };

    static constexpr size_t bytes_on_stack = 64;
    using PatternActions = PODArray<PatternAction, bytes_on_stack, AllocatorWithStackMemory<Allocator<false>, bytes_on_stack>>;


    /// Compiles `pattern` into `actions`. Throws SYNTAX_ERROR on malformed input and
    /// BAD_ARGUMENTS on semantically invalid references or values.
    void parsePattern()
    {
        actions.clear();
        /// implicit leading `.*`: the pattern may start matching at any event
        actions.emplace_back(PatternActionType::KleeneStar);

        const char * pos = pattern.data();
        const char * begin = pos;
        const char * end = pos + pattern.size();

        auto throw_exception = [&](const std::string & msg)
        {
            throw Exception{
                msg + " '" + std::string(pos, end) + "' at position " + std::to_string(pos - begin),
                ErrorCodes::SYNTAX_ERROR};
        };

        /// consumes `str` if the remaining input starts with it
        auto match = [&pos, end](const char * str) mutable
        {
            size_t length = strlen(str);
            /// compare against the remaining length so no pointer beyond `end` is ever formed
            if (length <= static_cast<size_t>(end - pos) && 0 == memcmp(pos, str, length))
            {
                pos += length;
                return true;
            }
            return false;
        };

        while (pos < end)
        {
            if (match("(?"))
            {
                if (match("t"))
                {
                    PatternActionType type{};

                    if (match("<="))
                        type = PatternActionType::TimeLessOrEqual;
                    else if (match("<"))
                        type = PatternActionType::TimeLess;
                    else if (match(">="))
                        type = PatternActionType::TimeGreaterOrEqual;
                    else if (match(">"))
                        type = PatternActionType::TimeGreater;
                    else
                        throw_exception("Unknown time condition");

                    UInt64 duration = 0;
                    auto prev_pos = pos;
                    pos = tryReadIntText(duration, pos, end);
                    if (pos == prev_pos)
                        throw_exception("Could not parse number");

                    /// PatternAction::extra is 32-bit; reject values that would be silently truncated
                    if (duration > std::numeric_limits<std::uint32_t>::max())
                        throw Exception{
                            "Time duration " + std::to_string(duration) + " is out of range",
                            ErrorCodes::BAD_ARGUMENTS};

                    if (actions.back().type != PatternActionType::SpecificEvent &&
                        actions.back().type != PatternActionType::AnyEvent &&
                        actions.back().type != PatternActionType::KleeneStar)
                        throw Exception{
                            "Temporal condition should be preceeded by an event condition",
                            ErrorCodes::BAD_ARGUMENTS};

                    actions.emplace_back(type, static_cast<std::uint32_t>(duration));
                }
                else
                {
                    UInt64 event_number = 0;
                    auto prev_pos = pos;
                    pos = tryReadIntText(event_number, pos, end);
                    if (pos == prev_pos)
                        throw_exception("Could not parse number");

                    /// events are numbered from 1: `(?0)` would wrap `event_number - 1` around and
                    /// make Events::test() throw std::out_of_range while matching
                    if (event_number == 0 || event_number > arg_count - 1)
                        throw Exception{
                            "Event number " + std::to_string(event_number) + " is out of range",
                            ErrorCodes::BAD_ARGUMENTS};

                    actions.emplace_back(PatternActionType::SpecificEvent, static_cast<std::uint32_t>(event_number - 1));
                }

                if (!match(")"))
                    throw_exception("Expected closing parenthesis, found");
            }
            else if (match(".*"))
                actions.emplace_back(PatternActionType::KleeneStar);
            else if (match("."))
                actions.emplace_back(PatternActionType::AnyEvent);
            else
                throw_exception("Could not parse pattern, unexpected starting symbol");
        }
    }

protected:
    /// Backtracking matcher: tries to match `actions` against the sorted events starting at `events_it`.
    /// On return `events_it` has moved forward by at least one event unless it started at `events_end`,
    /// which lets sequenceCount resume scanning from it. Throws TOO_SLOW when a single attempt
    /// exceeds sequence_match_max_iterations steps.
    template <typename T>
    bool match(T & events_it, const T events_end) const
    {
        const auto action_begin = std::begin(actions);
        const auto action_end = std::end(actions);
        auto action_it = action_begin;

        const auto events_begin = events_it;
        auto base_it = events_it;

        /// action to resume from, event that was current at the save point, and the event whose
        /// timestamp anchors subsequent time conditions
        using backtrack_info = std::tuple<decltype(action_it), T, T>;
        std::stack<backtrack_info> back_stack;

        /// backtrack if possible
        const auto do_backtrack = [&]
        {
            while (!back_stack.empty())
            {
                auto & top = back_stack.top();

                action_it = std::get<0>(top);
                events_it = std::next(std::get<1>(top));
                base_it = std::get<2>(top);

                back_stack.pop();

                if (events_it != events_end)
                    return true;
            }

            return false;
        };

        size_t i = 0;
        while (action_it != action_end && events_it != events_end)
        {
            if (action_it->type == PatternActionType::SpecificEvent)
            {
                if (events_it->second.test(action_it->extra))
                {
                    /// move to the next action and events
                    base_it = events_it;
                    ++action_it, ++events_it;
                }
                else if (!do_backtrack())
                    /// backtracking failed, bail out
                    break;
            }
            else if (action_it->type == PatternActionType::AnyEvent)
            {
                base_it = events_it;
                ++action_it, ++events_it;
            }
            else if (action_it->type == PatternActionType::KleeneStar)
            {
                back_stack.emplace(action_it, events_it, base_it);
                base_it = events_it;
                ++action_it;
            }
            else if (action_it->type == PatternActionType::TimeLessOrEqual)
            {
                if (events_it->first - base_it->first <= action_it->extra)
                {
                    /// condition satisfied, move onto next action
                    back_stack.emplace(action_it, events_it, base_it);
                    base_it = events_it;
                    ++action_it;
                }
                else if (!do_backtrack())
                    break;
            }
            else if (action_it->type == PatternActionType::TimeLess)
            {
                if (events_it->first - base_it->first < action_it->extra)
                {
                    back_stack.emplace(action_it, events_it, base_it);
                    base_it = events_it;
                    ++action_it;
                }
                else if (!do_backtrack())
                    break;
            }
            else if (action_it->type == PatternActionType::TimeGreaterOrEqual)
            {
                if (events_it->first - base_it->first >= action_it->extra)
                {
                    back_stack.emplace(action_it, events_it, base_it);
                    base_it = events_it;
                    ++action_it;
                }
                else if (++events_it == events_end && !do_backtrack())
                    break;
            }
            else if (action_it->type == PatternActionType::TimeGreater)
            {
                if (events_it->first - base_it->first > action_it->extra)
                {
                    back_stack.emplace(action_it, events_it, base_it);
                    base_it = events_it;
                    ++action_it;
                }
                else if (++events_it == events_end && !do_backtrack())
                    break;
            }
            else
                throw Exception{
                    "Unknown PatternActionType",
                    ErrorCodes::LOGICAL_ERROR};

            if (++i > sequence_match_max_iterations)
                throw Exception{
                    "Pattern application proves too difficult, exceeding max iterations (" + toString(sequence_match_max_iterations) + ")",
                    ErrorCodes::TOO_SLOW};
        }

        /// Match multiple empty strings at the end. The bound check matters: a pattern ending in
        /// `.*` or a time condition would otherwise step past action_end and dereference it.
        while (action_it != action_end &&
               (action_it->type == PatternActionType::KleeneStar ||
                action_it->type == PatternActionType::TimeLessOrEqual ||
                action_it->type == PatternActionType::TimeLess ||
                (action_it->type == PatternActionType::TimeGreaterOrEqual && action_it->extra == 0)))
            ++action_it;

        /// guarantee forward progress for the caller, but never step past the end of the events
        if (events_it == events_begin && events_it != events_end)
            ++events_it;

        return action_it == action_end;
    }

private:
    std::string pattern;
    size_t arg_count = 0;
    PatternActions actions;
};

495 496 497
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceMatch
{
public:
498
    String getName() const override { return "sequenceCount"; }
499

500
    DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt64>(); }
501

502 503 504 505 506
    void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
    {
        const_cast<Data &>(data(place)).sort();
        static_cast<ColumnUInt64 &>(to).getData().push_back(count(place));
    }
507 508

private:
509 510 511
    UInt64 count(const ConstAggregateDataPtr & place) const
    {
        const auto & data_ref = data(place);
512

A
Alexey Milovidov 已提交
513 514
        const auto events_begin = std::begin(data_ref.events_list);
        const auto events_end = std::end(data_ref.events_list);
515
        auto events_it = events_begin;
516

517
        size_t count = 0;
518 519
        while (events_it != events_end && match(events_it, events_end))
            ++count;
520

521 522
        return count;
    }
523 524
};

525
}