提交 e9246b74 编写于 作者: X Xu Jingxin

Test slicing

上级 d40bdf16
......@@ -8,6 +8,13 @@ def apply_middleware(func_name: str):
def wrap_func(base_func: Callable):
def handler(buffer, *args, **kwargs):
"""
The real processing starts here, we apply the middlewares one by one,
each middleware will receive a `next` function, which is an executor of next
middleware. You can change the input arguments to the `next` middleware, and you
also can get the return value from the next middleware, so you have the
maximum freedom to choose at what stage to implement your method.
"""
def wrap_handler(middlewares, *args, **kwargs):
if len(middlewares) == 0:
......@@ -49,7 +56,7 @@ class Buffer:
self.storage.append(data)
@apply_middleware("sample")
def sample(self, size: int, replace: bool = False) -> List[Any]:
def sample(self, size: int, replace: bool = False, range: slice = None) -> List[Any]:
"""
Overview:
Sample data with length ``size``, this function may be wrapped by middlewares.
......@@ -58,7 +65,7 @@ class Buffer:
Returns:
- sample_data (:obj:`list`): A list of data with length ``size``.
"""
return self.storage.sample(size, replace)
return self.storage.sample(size, replace=replace, range=range)
@apply_middleware("clear")
def clear(self) -> None:
......
......@@ -2,8 +2,8 @@ from typing import Any, List
from collections import deque
from operator import itemgetter
from ding.worker.buffer import Storage
import random
import numpy as np
import itertools
class MemoryStorage(Storage):
......@@ -17,8 +17,11 @@ class MemoryStorage(Storage):
def get(self, indices: List[int]) -> List[Any]:
return itemgetter(*indices)(self.storage)
def sample(self, size: int, replace: bool = False) -> List[Any]:
return np.random.choice(self.storage, size, replace=replace)
def sample(self, size: int, replace: bool = False, range: slice = None) -> List[Any]:
storage = self.storage
if range:
storage = list(itertools.islice(self.storage, range.start, range.stop, range.step))
return np.random.choice(storage, size, replace=replace)
def count(self) -> int:
return len(self.storage)
......
......@@ -13,7 +13,7 @@ class Storage:
raise NotImplementedError
@abstractmethod
def sample(self, size: int, replace: bool = False) -> List[Any]:
def sample(self, size: int, replace: bool = False, range: slice = None) -> List[Any]:
raise NotImplementedError
@abstractmethod
......
......@@ -75,6 +75,13 @@ def test_naive_push_sample():
assert storage.count() == 5
assert len(buffer.sample(10, replace=True)) == 10
# Test slicing
buffer.clear()
for i in range(10):
buffer.push(i)
assert len(buffer.sample(5, range=slice(5, 10))) == 5
assert 0 not in buffer.sample(5, range=slice(5, 10))
@pytest.mark.unittest
def test_rate_limit_push_sample():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册