From f764de31439be4ae21b749a735a732b7f4a7ff54 Mon Sep 17 00:00:00 2001 From: Xu Jingxin Date: Wed, 3 Nov 2021 18:29:18 +0800 Subject: [PATCH] Add ignore_insufficient --- ding/worker/buffer/buffer.py | 5 ++++- ding/worker/buffer/deque_buffer.py | 21 ++++++++++++++++++--- ding/worker/buffer/tests/test_buffer.py | 12 ++++++++++++ 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/ding/worker/buffer/buffer.py b/ding/worker/buffer/buffer.py index 14287873..69f43830 100644 --- a/ding/worker/buffer/buffer.py +++ b/ding/worker/buffer/buffer.py @@ -68,7 +68,8 @@ class Buffer: size: Optional[int] = None, indices: Optional[List[str]] = None, replace: bool = False, - range: Optional[slice] = None + range: Optional[slice] = None, + ignore_insufficient: bool = False ) -> List[BufferedData]: """ Overview: @@ -78,6 +79,8 @@ class Buffer: - indices (:obj:`Optional[List[str]]`): Sample with multiple indices. - replace (:obj:`bool`): If use replace is true, you may receive duplicated data from the buffer. - range (:obj:`slice`): Range slice. + - ignore_insufficient (:obj:`bool`): If ignore_insufficient is true, sampling more than buffer size + with no repetition will not cause an exception. Returns: - sample_data (:obj:`List[BufferedData]`): A list of data with length ``size``. diff --git a/ding/worker/buffer/deque_buffer.py b/ding/worker/buffer/deque_buffer.py index fca73d43..3d55cae0 100644 --- a/ding/worker/buffer/deque_buffer.py +++ b/ding/worker/buffer/deque_buffer.py @@ -24,7 +24,8 @@ class DequeBuffer(Buffer): size: Optional[int] = None, indices: Optional[List[str]] = None, replace: bool = False, - range: Optional[slice] = None + range: Optional[slice] = None, + ignore_insufficient: bool = False ) -> List[BufferedData]: storage = self.storage if range: @@ -32,11 +33,25 @@ class DequeBuffer(Buffer): assert size or indices, "One of size and indices must not be empty." if (size and indices) and (size != len(indices)): raise AssertionError("Size and indices length must be equal.") + if not size: + size = len(indices) + value_error = None + sampled_data = [] if indices: - sampled_data = filter(lambda item: item.index in indices, self.storage) + sampled_data = list(filter(lambda item: item.index in indices, self.storage)) else: - sampled_data = random.choices(storage, k=size) if replace else random.sample(storage, k=size) + if replace: + sampled_data = random.choices(storage, k=size) + else: + try: + sampled_data = random.sample(storage, k=size) + except ValueError as e: + value_error = e + + if not ignore_insufficient and (value_error or len(sampled_data) != size): + raise ValueError("There are less than {} data in buffer".format(size)) + return sampled_data @apply_middleware("update") diff --git a/ding/worker/buffer/tests/test_buffer.py b/ding/worker/buffer/tests/test_buffer.py index 55e545b9..013c1c21 100644 --- a/ding/worker/buffer/tests/test_buffer.py +++ b/ding/worker/buffer/tests/test_buffer.py @@ -153,3 +153,15 @@ def test_update_delete(): [item] = buf.sample(1) buf.delete(item.index) assert buf.count() == 0 + + +@pytest.mark.unittest +def test_ignore_insufficient(): + buffer = DequeBuffer(size=10) + for i in range(2): + buffer.push(i) + + with pytest.raises(ValueError): + buffer.sample(3) + data = buffer.sample(3, ignore_insufficient=True) + assert len(data) == 0 -- GitLab