diff --git a/ding/worker/buffer/buffer.py b/ding/worker/buffer/buffer.py index d4b54b540eb3be284bdfb7794580cfb161b24d5f..c54f480a062a9e28baaf582e137e89f3ae42da69 100644 --- a/ding/worker/buffer/buffer.py +++ b/ding/worker/buffer/buffer.py @@ -8,6 +8,13 @@ def apply_middleware(func_name: str): def wrap_func(base_func: Callable): def handler(buffer, *args, **kwargs): + """ + The real processing starts here, we apply the middlewares one by one, + each middleware will receive a `next` function, which is an executor of next + middleware. You can change the input arguments to the `next` middleware, and you + also can get the return value from the next middleware, so you have the + maximum freedom to choose at what stage to implement your method. + """ def wrap_handler(middlewares, *args, **kwargs): if len(middlewares) == 0: @@ -49,7 +56,7 @@ class Buffer: self.storage.append(data) @apply_middleware("sample") - def sample(self, size: int, replace: bool = False) -> List[Any]: + def sample(self, size: int, replace: bool = False, range: slice = None) -> List[Any]: """ Overview: Sample data with length ``size``, this function may be wrapped by middlewares. @@ -58,7 +65,7 @@ class Buffer: Returns: - sample_data (:obj:`list`): A list of data with length ``size``. """ - return self.storage.sample(size, replace) + return self.storage.sample(size, replace=replace, range=range) @apply_middleware("clear") def clear(self) -> None: diff --git a/ding/worker/buffer/memory_storage.py b/ding/worker/buffer/memory_storage.py index 70c343cdf8b01a19ff4d0933bb729e74ce9f3937..9525890b54fdfc4577ee43c6cb52150495d65044 100644 --- a/ding/worker/buffer/memory_storage.py +++ b/ding/worker/buffer/memory_storage.py @@ -2,8 +2,8 @@ from typing import Any, List from collections import deque from operator import itemgetter from ding.worker.buffer import Storage -import random import numpy as np +import itertools class MemoryStorage(Storage): @@ -17,8 +17,11 @@ class MemoryStorage(Storage): def get(self, indices: List[int]) -> List[Any]: return itemgetter(*indices)(self.storage) - def sample(self, size: int, replace: bool = False) -> List[Any]: - return np.random.choice(self.storage, size, replace=replace) + def sample(self, size: int, replace: bool = False, range: slice = None) -> List[Any]: + storage = self.storage + if range: + storage = list(itertools.islice(self.storage, range.start, range.stop, range.step)) + return np.random.choice(storage, size, replace=replace) def count(self) -> int: return len(self.storage) diff --git a/ding/worker/buffer/storage.py b/ding/worker/buffer/storage.py index 623564f214d842744c9e8a8661b062b2a838d9d4..641015524f074b6ce05d4e4d4603b9d54ff2c959 100644 --- a/ding/worker/buffer/storage.py +++ b/ding/worker/buffer/storage.py @@ -13,7 +13,7 @@ class Storage: raise NotImplementedError @abstractmethod - def sample(self, size: int, replace: bool = False) -> List[Any]: + def sample(self, size: int, replace: bool = False, range: slice = None) -> List[Any]: raise NotImplementedError @abstractmethod diff --git a/ding/worker/buffer/tests/test_buffer.py b/ding/worker/buffer/tests/test_buffer.py index 756f25de08685c7dc0daf47762a0d5d36ab1c0b3..8d3d8b2f70fd34075edcdd71976bcaef378a5ce9 100644 --- a/ding/worker/buffer/tests/test_buffer.py +++ b/ding/worker/buffer/tests/test_buffer.py @@ -75,6 +75,13 @@ def test_naive_push_sample(): assert storage.count() == 5 assert len(buffer.sample(10, replace=True)) == 10 + # Test slicing + buffer.clear() + for i in range(10): + buffer.push(i) + assert len(buffer.sample(5, range=slice(5, 10))) == 5 + assert 0 not in buffer.sample(5, range=slice(5, 10)) + @pytest.mark.unittest def test_rate_limit_push_sample():