提交 f764de31 编写于 作者: X Xu Jingxin

Add ignore_insufficient

上级 6e93ff58
......@@ -68,7 +68,8 @@ class Buffer:
size: Optional[int] = None,
indices: Optional[List[str]] = None,
replace: bool = False,
range: Optional[slice] = None
range: Optional[slice] = None,
ignore_insufficient: bool = False
) -> List[BufferedData]:
"""
Overview:
......@@ -78,6 +79,8 @@ class Buffer:
- indices (:obj:`Optional[List[str]]`): Sample with multiple indices.
- replace (:obj:`bool`): If use replace is true, you may receive duplicated data from the buffer.
- range (:obj:`slice`): Range slice.
- ignore_insufficient (:obj:`bool`): If ignore_insufficient is true, sampling more than buffer size
with no repetition will not cause an exception.
Returns:
- sample_data (:obj:`List[BufferedData]`):
A list of data with length ``size``.
......
......@@ -24,7 +24,8 @@ class DequeBuffer(Buffer):
size: Optional[int] = None,
indices: Optional[List[str]] = None,
replace: bool = False,
range: Optional[slice] = None
range: Optional[slice] = None,
ignore_insufficient: bool = False
) -> List[BufferedData]:
storage = self.storage
if range:
......@@ -32,11 +33,25 @@ class DequeBuffer(Buffer):
assert size or indices, "One of size and indices must not be empty."
if (size and indices) and (size != len(indices)):
raise AssertionError("Size and indices length must be equal.")
if not size:
size = len(indices)
value_error = None
sampled_data = []
if indices:
sampled_data = filter(lambda item: item.index in indices, self.storage)
sampled_data = list(filter(lambda item: item.index in indices, self.storage))
else:
sampled_data = random.choices(storage, k=size) if replace else random.sample(storage, k=size)
if replace:
sampled_data = random.choices(storage, k=size)
else:
try:
sampled_data = random.sample(storage, k=size)
except ValueError as e:
value_error = e
if not ignore_insufficient and (value_error or len(sampled_data) != size):
raise ValueError("There are less than {} data in buffer".format(size))
return sampled_data
@apply_middleware("update")
......
......@@ -153,3 +153,15 @@ def test_update_delete():
[item] = buf.sample(1)
buf.delete(item.index)
assert buf.count() == 0
@pytest.mark.unittest
def test_ignore_insufficient():
buffer = DequeBuffer(size=10)
for i in range(2):
buffer.push(i)
with pytest.raises(ValueError):
buffer.sample(3)
data = buffer.sample(3, ignore_insufficient=True)
assert len(data) == 0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册