提交 638f5110 编写于 作者: N niuyazhe 提交者: Xu Jingxin

feature(nyz): add naive priority experience replay

上级 d066372a
from .clone_object import clone_object
from .use_time_check import use_time_check
from .staleness_check import staleness_check
from .priority import priority
from typing import Callable, Any, List, Dict, Optional
import copy
import numpy as np
from ding.utils import SumSegmentTree, MinSegmentTree
class PriorityExperienceReplay:
    """
    Naive prioritized experience replay (PER) implemented as buffer middleware.

    Priorities are kept in two segment trees indexed by a circular cursor
    (``pivot``): a sum tree for proportional sampling and, when ``IS_weight``
    is enabled, a min tree used to normalize importance-sampling weights.

    Reference: Schaul et al., "Prioritized Experience Replay", ICLR 2016.
    """

    def __init__(
            self,
            buffer: 'Buffer',  # noqa
            buffer_size: int,
            IS_weight: bool = True,
            priority_power_factor: float = 0.6,
            IS_weight_power_factor: float = 0.4,
            IS_weight_anneal_train_iter: int = int(1e5)
    ) -> None:
        """
        Arguments:
            - buffer: the wrapped buffer instance (used for ``count()`` and ``storage``).
            - buffer_size: logical capacity; tree capacity is rounded up to a power of 2.
            - IS_weight: whether to compute importance-sampling correction weights.
            - priority_power_factor: alpha exponent applied to raw priorities.
            - IS_weight_power_factor: initial beta exponent for IS weights (annealed to 1).
            - IS_weight_anneal_train_iter: number of ``sample`` calls over which beta anneals to 1.
        """
        self.buffer = buffer
        self.buffer_size = buffer_size
        self.IS_weight = IS_weight
        self.priority_power_factor = priority_power_factor
        self.IS_weight_power_factor = IS_weight_power_factor
        self.IS_weight_anneal_train_iter = IS_weight_anneal_train_iter
        # Max priority till now; used to initialize data's priority if "priority"
        # is not passed in with the data, so new data is sampled at least once.
        self.max_priority = 1.0
        # Segment-tree capacity needs to be a power of 2.
        capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
        self.sum_tree = SumSegmentTree(capacity)
        if self.IS_weight:
            self.min_tree = MinSegmentTree(capacity)
            # Per-sample-call increment that anneals beta linearly towards 1.
            self.delta_anneal = (1 - self.IS_weight_power_factor) / self.IS_weight_anneal_train_iter
        # Circular write cursor into the trees.
        self.pivot = 0

    def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> None:
        """
        Attach priority metadata to ``data`` and forward the push down the chain.
        Missing priorities default to the running maximum.
        """
        if meta is None:
            meta = {'priority': self.max_priority}
        elif 'priority' not in meta:
            meta['priority'] = self.max_priority
        meta['priority_idx'] = self.pivot
        self._update_tree(meta['priority'], self.pivot)
        self.pivot = (self.pivot + 1) % self.buffer_size
        return chain(data, meta=meta, *args, **kwargs)

    def sample(self, chain: Callable, size: int, *args, **kwargs) -> List[Any]:
        """
        Sample ``size`` items via stratified proportional sampling and, if enabled,
        annotate each item's meta with a normalized IS weight (``priority_IS``).
        """
        # Divide [0, 1) into `size` intervals on average.
        intervals = np.array([i * 1.0 / size for i in range(size)])
        # Uniformly sample within each interval (stratified sampling).
        mass = intervals + np.random.uniform(size=(size, )) * 1. / size
        # Rescale to [0, S), where S is the sum of all priorities (root of sum tree).
        mass *= self.sum_tree.reduce()
        indices = [self.sum_tree.find_prefixsum_idx(m) for m in mass]
        # TODO sample with indices
        data = chain(size, return_index=True, return_meta=True, *args, **kwargs)
        if self.IS_weight:
            # Calculate max weight for normalizing IS weights into (0, 1].
            sum_tree_root = self.sum_tree.reduce()
            p_min = self.min_tree.reduce() / sum_tree_root
            buffer_count = self.buffer.count()
            max_weight = (buffer_count * p_min) ** (-self.IS_weight_power_factor)
            for i in range(len(data)):
                meta = data[i][-1]
                priority_idx = meta['priority_idx']
                p_sample = self.sum_tree[priority_idx] / sum_tree_root
                weight = (buffer_count * p_sample) ** (-self.IS_weight_power_factor)
                meta['priority_IS'] = weight / max_weight
            # Anneal beta towards 1 (full IS correction) over training.
            self.IS_weight_power_factor = min(1.0, self.IS_weight_power_factor + self.delta_anneal)
        return data

    def update(self, chain: Callable, index: str, data: Any, meta: dict, *args, **kwargs) -> None:
        """
        Forward the update down the chain and, on success, refresh the item's
        priority in the trees using ``meta['priority']``.
        """
        update_flag = chain(index, data, meta, *args, **kwargs)
        if update_flag:  # when update succeeds
            new_priority, idx = meta['priority'], meta['priority_idx']
            assert new_priority >= 0, "new_priority should be non-negative, but found {}".format(new_priority)
            new_priority += 1e-5  # Add epsilon to avoid priority == 0
            self._update_tree(new_priority, idx)
            self.max_priority = max(self.max_priority, new_priority)

    def delete(self, chain: Callable, index: str, *args, **kwargs) -> None:
        """
        Reset tree entries before forwarding the delete down the chain.

        NOTE(review): this clears the priority of EVERY item currently in
        storage, not just the one being deleted — presumably a naive
        placeholder; confirm intended semantics against the buffer contract.
        """
        for (_, _, meta) in self.buffer.storage:
            priority_idx = meta['priority_idx']
            self.sum_tree[priority_idx] = self.sum_tree.neutral_element
            if self.IS_weight:
                self.min_tree[priority_idx] = self.min_tree.neutral_element
        return chain(index, *args, **kwargs)

    def clear(self, chain: Callable) -> None:
        """Reset all priority state (trees, cursor, max priority) and forward clear."""
        self.max_priority = 1.0
        capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
        self.sum_tree = SumSegmentTree(capacity)
        if self.IS_weight:
            self.min_tree = MinSegmentTree(capacity)
        self.pivot = 0
        chain()

    def _update_tree(self, priority: float, idx: int) -> None:
        """Write ``priority ** alpha`` into the trees at position ``idx``."""
        weight = priority ** self.priority_power_factor
        self.sum_tree[idx] = weight
        # min_tree only exists when IS_weight is enabled (see __init__);
        # the original touched it unconditionally and crashed for IS_weight=False.
        if self.IS_weight:
            self.min_tree[idx] = weight

    def state_dict(self) -> Dict:
        """
        Serialize priority state. Keys match attribute names so that
        ``load_state_dict`` can restore them symmetrically.
        (Bug fix: original referenced non-existent ``self.sumtree``/``self.mintree``.)
        """
        state = {
            'max_priority': self.max_priority,
            'IS_weight_power_factor': self.IS_weight_power_factor,
            'sum_tree': self.sum_tree,
        }
        if self.IS_weight:
            state['min_tree'] = self.min_tree
        return state

    def load_state_dict(self, _state_dict: Dict, deepcopy: bool = False) -> None:
        """
        Restore state produced by ``state_dict``.
        (Bug fix: original set ``_``-prefixed attribute names that nothing reads.)

        Arguments:
            - _state_dict: mapping of attribute name -> value.
            - deepcopy: if True, store deep copies instead of shared references.
        """
        for k, v in _state_dict.items():
            if deepcopy:
                setattr(self, k, copy.deepcopy(v))
            else:
                setattr(self, k, v)
def priority(*per_args, **per_kwargs):
    """
    Middleware factory wrapping a Buffer with prioritized experience replay.

    All arguments are forwarded to ``PriorityExperienceReplay``. Returns a
    middleware callable with signature ``(action, chain, *args, **kwargs)``.
    """
    per = PriorityExperienceReplay(*per_args, **per_kwargs)

    def _priority(action: str, chain: Callable, *args, **kwargs) -> Any:
        if action in ["push", "sample", "update", "delete", "clear"]:
            return getattr(per, action)(chain, *args, **kwargs)
        # Fall through to the next middleware/storage for unhandled actions.
        # Bug fix: the original called ``chain(chain, ...)``, passing the chain
        # to itself as the first positional argument.
        return chain(*args, **kwargs)

    return _priority
import pytest
import torch
from ding.worker.buffer import Buffer, DequeStorage
from ding.worker.buffer.middleware import clone_object, use_time_check, staleness_check
from ding.worker.buffer.middleware import clone_object, use_time_check, staleness_check, priority
@pytest.mark.unittest
......@@ -76,3 +76,29 @@ def test_staleness_check():
with pytest.raises(ValueError):
data = buffer.sample(size=N, replace=False, train_iter_sample_data=11)
assert buffer.count() == 2
@pytest.mark.unittest
def test_priority():
    """End-to-end check of the priority middleware: push, sample, update, delete, clear."""
    N = 5
    buffer = Buffer(DequeStorage(maxlen=10))
    buffer.use(priority(buffer, buffer_size=10, IS_weight=True))
    # Half the pushes rely on the default priority, half supply one explicitly.
    for _ in range(N):
        buffer.push(get_data())
    assert buffer.count() == N
    for _ in range(N):
        buffer.push(get_data(), meta={'priority': 2.0})
    assert buffer.count() == 2 * N
    sampled = buffer.sample(size=2 * N, replace=False)
    assert len(sampled) == 2 * N
    required_keys = {'priority', 'priority_idx', 'priority_IS'}
    for _, _, meta in sampled:
        assert required_keys.issubset(meta.keys())
        meta['priority'] = 3.0
    for item, index, meta in sampled:
        buffer.update(index, item, meta)
    sampled = buffer.sample(size=1)
    assert sampled[0][2]['priority'] == 3.0
    buffer.delete(sampled[0][1])
    assert buffer.count() == 2 * N - 1
    buffer.clear()
    assert buffer.count() == 0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册