提交 1afceefe 编写于 作者: X Xu Jingxin

Use ratelimit as middleware

上级 02a0e808
from .buffer import Buffer from .buffer import Buffer, RateLimit
from .naive_buffer import NaiveBuffer
from .storage import Storage from .storage import Storage
from .memory_storage import MemoryStorage from .memory_storage import MemoryStorage
from abc import abstractmethod from typing import Any, Callable, List
from typing import Any, List
from ding.worker.buffer.storage import Storage from ding.worker.buffer.storage import Storage
def apply_middleware(func_name: str):
def wrap_func(f: Callable):
def _apply_middleware(buffer, *args):
for func in buffer.middlewares:
done, *args = func(buffer, func_name, *args)
if done:
return args
return f(buffer, *args)
return _apply_middleware
return wrap_func
class Buffer: class Buffer:
def __init__(self, storage: Storage) -> None: def __init__(self, storage: Storage, **kwargs) -> None:
self.storage = storage self.storage = storage
self.middlewares = []
@abstractmethod @apply_middleware("push")
def push(self, data: Any) -> None: def push(self, data: Any) -> None:
raise NotImplementedError self.storage.append(data)
@abstractmethod @apply_middleware("sample")
def sample(self, size: int) -> List[Any]: def sample(self, size: int) -> List[Any]:
raise NotImplementedError return self.storage.sample(size)
@abstractmethod @apply_middleware("clear")
def clear(self) -> None: def clear(self) -> None:
raise NotImplementedError self.storage.clear()
def use(self, func: Callable) -> "Buffer":
r"""
Overview:
Use algorithm middlewares to modify the behavior of the buffer.
Every middleware should be a callable function, it will receive three argument parts, including:
1. The buffer instance, you can use this instance to visit every thing of the buffer, including the storage.
2. The functions called by the user, there are three methods named `push`, `sample` and `clear`, so you can use these function name to decide which action to choose.
3. The remaining arguments passed by the user to the original function, will be passed in *args.
Each middleware handler should return two parts of the value, including:
1. The first value is `done` (True or False), if done==True, the middleware chain will stop immediately, no more middlewares will be executed during this execution
2. The remaining values, will be passed to the next middleware or the default function in the buffer.
Arguments:
- func (:obj:`Callable`): the middleware handler
"""
self.middlewares.append(func)
return self
class RateLimit:
r"""
Add rate limit threshold to push function
"""
def __init__(self, max_rate: int = float("inf"), window_seconds: int = 30) -> None:
self.max_rate = max_rate
self.window_seconds = window_seconds
self.buffered = []
def handler(self) -> Callable:
def _handler(buffer: Buffer, action: str, *args):
if action == "push":
return self.push(*args)
return args
return _handler
def push(self, data) -> None:
import time
current = time.time()
# Cut off stale records
self.buffered = [t for t in self.buffered if t > current - self.window_seconds]
if len(self.buffered) < self.max_rate:
self.buffered.append(current)
return False, data
else:
return True, None
from typing import Any, List
from ding.worker.buffer import Buffer
from ding.worker.buffer.storage import Storage
class NaiveBuffer(Buffer):
def __init__(self, storage: Storage, **kwargs) -> None:
super().__init__(storage, **kwargs)
def push(self, data: Any) -> None:
self.storage.append(data)
def sample(self, size: int) -> List[Any]:
return self.storage.sample(size)
def clear(self) -> None:
self.storage.clear()
import pytest import pytest
from ding.worker.buffer import NaiveBuffer from ding.worker.buffer import Buffer
from ding.worker.buffer import MemoryStorage from ding.worker.buffer import MemoryStorage
from ding.worker.buffer.buffer import RateLimit
@pytest.mark.unittest @pytest.mark.unittest
def test_naive_push_sample(): def test_naive_push_sample():
storage = MemoryStorage(maxlen=10) storage = MemoryStorage(maxlen=10)
naive_buffer = NaiveBuffer(storage) buffer = Buffer(storage)
for i in range(20): for i in range(20):
naive_buffer.push(i) buffer.push(i)
assert storage.count() == 10 assert storage.count() == 10
assert len(set(naive_buffer.sample(10))) == 10 assert len(set(buffer.sample(10))) == 10
assert 0 not in naive_buffer.sample(10) assert 0 not in buffer.sample(10)
@pytest.mark.unittest
def test_rate_limit_push_sample():
storage = MemoryStorage(maxlen=10)
ratelimit = RateLimit(max_rate=5)
buffer = Buffer(storage).use(ratelimit.handler())
for i in range(10):
buffer.push(i)
assert storage.count() == 5
assert 5 not in buffer.sample(5)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册