diff --git a/ding/worker/buffer/__init__.py b/ding/worker/buffer/__init__.py index 8cdf53212fddafb7e8fcb4fb1dd41c2d8f95bccf..44055b6932f468b6fe6195cf503aa2983d487061 100644 --- a/ding/worker/buffer/__init__.py +++ b/ding/worker/buffer/__init__.py @@ -1,4 +1,3 @@ -from .buffer import Buffer -from .naive_buffer import NaiveBuffer +from .buffer import Buffer, RateLimit from .storage import Storage from .memory_storage import MemoryStorage diff --git a/ding/worker/buffer/buffer.py b/ding/worker/buffer/buffer.py index 2e7eebfd123108928e27e1eb3857327767b664ee..693b053678687db1cf960f1ce36fd3a1fba5fa0e 100644 --- a/ding/worker/buffer/buffer.py +++ b/ding/worker/buffer/buffer.py @@ -1,22 +1,86 @@ -from abc import abstractmethod -from typing import Any, List - +from typing import Any, Callable, List from ding.worker.buffer.storage import Storage +def apply_middleware(func_name: str): + + def wrap_func(f: Callable): + + def _apply_middleware(buffer, *args): + for func in buffer.middlewares: + done, *args = func(buffer, func_name, *args) + if done: + return args + return f(buffer, *args) + + return _apply_middleware + + return wrap_func + + class Buffer: - def __init__(self, storage: Storage) -> None: + def __init__(self, storage: Storage, **kwargs) -> None: self.storage = storage + self.middlewares = [] - @abstractmethod + @apply_middleware("push") def push(self, data: Any) -> None: - raise NotImplementedError + self.storage.append(data) - @abstractmethod + @apply_middleware("sample") def sample(self, size: int) -> List[Any]: - raise NotImplementedError + return self.storage.sample(size) - @abstractmethod + @apply_middleware("clear") def clear(self) -> None: - raise NotImplementedError + self.storage.clear() + + def use(self, func: Callable) -> "Buffer": + r""" + Overview: + Use algorithm middlewares to modify the behavior of the buffer. + Every middleware should be a callable function, it will receive three argument parts, including: + 1. The buffer instance, you can use this instance to visit every thing of the buffer, including the storage. + 2. The functions called by the user, there are three methods named `push`, `sample` and `clear`, so you can use these function name to decide which action to choose. + 3. The remaining arguments passed by the user to the original function, will be passed in *args. + + Each middleware handler should return two parts of the value, including: + 1. The first value is `done` (True or False), if done==True, the middleware chain will stop immediately, no more middlewares will be executed during this execution + 2. The remaining values, will be passed to the next middleware or the default function in the buffer. + Arguments: + - func (:obj:`Callable`): the middleware handler + """ + self.middlewares.append(func) + return self + + +class RateLimit: + r""" + Add rate limit threshold to push function + """ + + def __init__(self, max_rate: int = float("inf"), window_seconds: int = 30) -> None: + self.max_rate = max_rate + self.window_seconds = window_seconds + self.buffered = [] + + def handler(self) -> Callable: + + def _handler(buffer: Buffer, action: str, *args): + if action == "push": + return self.push(*args) + return args + + return _handler + + def push(self, data) -> None: + import time + current = time.time() + # Cut off stale records + self.buffered = [t for t in self.buffered if t > current - self.window_seconds] + if len(self.buffered) < self.max_rate: + self.buffered.append(current) + return False, data + else: + return True, None diff --git a/ding/worker/buffer/naive_buffer.py b/ding/worker/buffer/naive_buffer.py deleted file mode 100644 index 0a55a439310b3f509b47ea31b8c713f559314f85..0000000000000000000000000000000000000000 --- a/ding/worker/buffer/naive_buffer.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Any, List -from ding.worker.buffer import Buffer -from ding.worker.buffer.storage import Storage - - -class NaiveBuffer(Buffer): - - def __init__(self, storage: Storage, **kwargs) -> None: - super().__init__(storage, **kwargs) - - def push(self, data: Any) -> None: - self.storage.append(data) - - def sample(self, size: int) -> List[Any]: - return self.storage.sample(size) - - def clear(self) -> None: - self.storage.clear() diff --git a/ding/worker/buffer/tests/test_buffer.py b/ding/worker/buffer/tests/test_buffer.py index 10675cae5ad0000bc74558f086e3297c1c953ee5..f856accad1178b8f1140b3f489e86324583c75c6 100644 --- a/ding/worker/buffer/tests/test_buffer.py +++ b/ding/worker/buffer/tests/test_buffer.py @@ -1,14 +1,26 @@ import pytest -from ding.worker.buffer import NaiveBuffer +from ding.worker.buffer import Buffer from ding.worker.buffer import MemoryStorage +from ding.worker.buffer.buffer import RateLimit @pytest.mark.unittest def test_naive_push_sample(): storage = MemoryStorage(maxlen=10) - naive_buffer = NaiveBuffer(storage) + buffer = Buffer(storage) for i in range(20): - naive_buffer.push(i) + buffer.push(i) assert storage.count() == 10 - assert len(set(naive_buffer.sample(10))) == 10 - assert 0 not in naive_buffer.sample(10) + assert len(set(buffer.sample(10))) == 10 + assert 0 not in buffer.sample(10) + + +@pytest.mark.unittest +def test_rate_limit_push_sample(): + storage = MemoryStorage(maxlen=10) + ratelimit = RateLimit(max_rate=5) + buffer = Buffer(storage).use(ratelimit.handler()) + for i in range(10): + buffer.push(i) + assert storage.count() == 5 + assert 5 not in buffer.sample(5)