提交 7e9e4e88 编写于 作者: N niuyazhe

feature(nyz): add sample_range arg in replay buffer

上级 608fee41
......@@ -10,6 +10,13 @@ from ding.utils.autolog import TickTime
from .utils import UsedDataRemover, generate_id, SampledDataAttrMonitor, PeriodicThruputMonitor, ThruputController
def to_positive_index(idx: Union[int, None], size: int) -> Union[int, None]:
    """
    Overview:
        Convert a slice bound that may be negative into the equivalent non-negative bound
        for a sequence of length ``size``.
    Arguments:
        - idx (:obj:`Union[int, None]`): A slice bound (``slice.start`` or ``slice.stop``). \
            ``None`` and non-negative values are returned unchanged.
        - size (:obj:`int`): Length of the sequence the bound refers to.
    Returns:
        - index (:obj:`Union[int, None]`): ``None`` if ``idx`` is ``None``; otherwise a \
            non-negative index.
    """
    if idx is None or idx >= 0:
        return idx
    # A negative bound counts from the end. Clamp at 0 so that a bound reaching before the
    # beginning (e.g. ``-100`` with ``size == 64``) behaves like Python slicing instead of
    # yielding another negative number.
    return max(size + idx, 0)
@BUFFER_REGISTRY.register('advanced')
class AdvancedReplayBuffer(IBuffer):
r"""
......@@ -206,13 +213,15 @@ class AdvancedReplayBuffer(IBuffer):
if self._enable_track_used_data:
self._used_data_remover.close()
def sample(self, size: int, cur_learner_iter: int) -> Optional[list]:
r"""
def sample(self, size: int, cur_learner_iter: int, sample_range: slice = None) -> Optional[list]:
"""
Overview:
Sample data with length ``size``.
Arguments:
- size (:obj:`int`): The number of the data that will be sampled.
- cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness.
- sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \
means only sample among the last 10 data
Returns:
- sample_data (:obj:`list`): A list of data with length ``size``
ReturnsKeys:
......@@ -235,7 +244,7 @@ class AdvancedReplayBuffer(IBuffer):
)
return None
with self._lock:
indices = self._get_indices(size)
indices = self._get_indices(size, sample_range)
result = self._sample_with_indices(indices, cur_learner_iter)
# Deepcopy the entries of ``result`` that share an index, because ``self._get_indices`` may
# return duplicate indices, i.e. the same data may be sampled more than once.
......@@ -498,7 +507,7 @@ class AdvancedReplayBuffer(IBuffer):
# only the data passes all the check functions, would the check return True
return all([fn(d) for fn in self.check_list])
def _get_indices(self, size: int) -> list:
def _get_indices(self, size: int, sample_range: slice = None) -> list:
r"""
Overview:
Get the sample index list according to the priority probability.
......@@ -511,8 +520,16 @@ class AdvancedReplayBuffer(IBuffer):
intervals = np.array([i * 1.0 / size for i in range(size)])
# Uniformly sample within each interval
mass = intervals + np.random.uniform(size=(size, )) * 1. / size
if sample_range is None:
# Rescale to [0, S), where S is the sum of the priorities of all data (root value of the sum tree)
mass *= self._sum_tree.reduce()
else:
# Rescale to [a, b)
start = to_positive_index(sample_range.start, self._replay_buffer_size)
end = to_positive_index(sample_range.stop, self._replay_buffer_size)
a = self._sum_tree.reduce(0, start)
b = self._sum_tree.reduce(0, end)
mass = mass * (b - a) + a
# Find prefix sum index to sample with probability
return [self._sum_tree.find_prefixsum_idx(m) for m in mass]
......
......@@ -95,14 +95,16 @@ class NaiveReplayBuffer(IBuffer):
else:
self._append(data, cur_collector_envstep)
def sample(self, size: int, cur_learner_iter: int) -> Optional[list]:
r"""
def sample(self, size: int, cur_learner_iter: int, sample_range: slice = None) -> Optional[list]:
"""
Overview:
Sample data with length ``size``.
Arguments:
- size (:obj:`int`): The number of the data that will be sampled.
- cur_learner_iter (:obj:`int`): Learner's current iteration. \
Not used in naive buffer, but preserved for compatibility.
- sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \
means only sample among the last 10 data
Returns:
- sample_data (:obj:`list`): A list of data with length ``size``.
"""
......@@ -112,7 +114,7 @@ class NaiveReplayBuffer(IBuffer):
if not can_sample:
return None
with self._lock:
indices = self._get_indices(size)
indices = self._get_indices(size, sample_range)
result = self._sample_with_indices(indices, cur_learner_iter)
return result
......@@ -234,12 +236,14 @@ class NaiveReplayBuffer(IBuffer):
"""
self.close()
def _get_indices(self, size: int) -> list:
def _get_indices(self, size: int, sample_range: slice = None) -> list:
r"""
Overview:
Get the sample index list.
Arguments:
- size (:obj:`int`): The number of the data that will be sampled
- sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \
means only sample among the last 10 data
Returns:
- index_list (:obj:`list`): A list including all the sample indices, whose length should equal to ``size``.
"""
......@@ -248,7 +252,11 @@ class NaiveReplayBuffer(IBuffer):
tail = self._replay_buffer_size
else:
tail = self._tail
if sample_range is None:
indices = list(np.random.choice(a=tail, size=size, replace=False))
else:
indices = list(range(tail))[sample_range]
indices = list(np.random.choice(indices, size=size, replace=False))
return indices
def _sample_with_indices(self, indices: List[int], cur_learner_iter: int) -> list:
......
......@@ -129,6 +129,13 @@ class TestAdvancedBuffer:
if v > advanced_buffer._max_use:
assert advanced_buffer._data[k] is None
for _ in range(64):
data = generate_data()
data['priority'] = None
advanced_buffer.push(data, 0)
batch = advanced_buffer.sample(10, 0, sample_range=slice(-20, -2))
assert len(batch) == 10
def test_head_tail(self):
buffer_cfg = deep_merge_dicts(
AdvancedReplayBuffer.default_config(), EasyDict(dict(replay_buffer_size=64, max_use=4))
......
......@@ -43,7 +43,15 @@ class TestNaiveBuffer:
for _ in range(64):
naive_buffer.push(generate_data(), 0)
batch = naive_buffer.sample(32, 0)
assert (len(batch) == 32)
assert len(batch) == 32
last_one_batch = naive_buffer.sample(1, 0, sample_range=slice(-1, None))
assert len(last_one_batch) == 1
assert last_one_batch[0] == naive_buffer._data[-1]
batch = naive_buffer.sample(5, 0, sample_range=slice(-10, -2))
sample_range_data = naive_buffer._data[-10:-2]
assert len(batch) == 5
for b in batch:
assert any([b['data_id'] == d['data_id'] for d in sample_range_data])
# test clear
naive_buffer.clear()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册