Unverified commit a490729f, authored by Xu Jingxin, committed by GitHub

feature(xjx): refactor buffer (#129)

* Init base buffer and storage

* Use ratelimit as middleware

* Pass style check

* Keep the original return value

* Add buffer.view

* Add replace flag on sample, rewrite middleware processing

* Test slicing

* Add buffer copy middleware

* Add update/delete api in buffer, rename middleware

* Implement update and delete api of buffer

* add naive use time count middleware in buffer

* Rename next to chain

* feature(nyz): add staleness check middleware and polish buffer

* feature(nyz): add naive priority experience replay

* Sample by indices

* Combine buffer and storage layers

* Support indices when deleting items from the queue

* Use dataclass to save buffered data, remove return_index and return_meta

* Add ignore_insufficient

* polish(nyz): add return index in push and copy same data in sample

* Drop useless import

* Fix sample with indices, ensure return size is equal to input size or indices size

* Make sure sampled data in buffer is different from each other

* Support sample by grouped meta key

* Support sample by rolling window

* Add import/export data in buffer

* Padding after sampling from buffer

* Polish use_time_check

* Use buffer as dataset

* Set collate_fn in buffer test

* feature(nyz): add deque buffer compatibility wrapper and demo

* polish(nyz): polish code style and add pong dqn new deque buffer demo

* feature(nyz): add use_time_count compatibility in wrapper

* feature(nyz): add priority replay buffer compatibility in wrapper

* Improve performance of buffer.update

* polish(nyz): add priority max limit and correct flake8

* Use __call__ to rewrite middleware

* Rewrite buffer index

* Fix buffer delete

* Skip first item

* Rewrite buffer delete

* Use caller

* Use caller in priority

* Add group sample
Co-authored-by: niuyazhe <niuyazhe@sensetime.com>
Parent a7de696a
@@ -3,3 +3,4 @@ from .learner import *
from .replay_buffer import *
from .coordinator import *
from .adapter import *
from .buffer import *
from .buffer import Buffer, apply_middleware, BufferedData
from .deque_buffer import DequeBuffer
from .deque_buffer_wrapper import DequeBufferWrapper
from abc import abstractmethod
from typing import Any, List, Optional, Union, Callable
import copy
from dataclasses import dataclass
def apply_middleware(func_name: str):
def wrap_func(base_func: Callable):
def handler(buffer, *args, **kwargs):
"""
Overview:
The real processing starts here: the middleware are applied one by one.
Each middleware receives the next `chain` function, which executes the next
middleware (or the base function at the end of the chain). You can change the
arguments passed to `chain` and use its return value, so you have maximum
freedom to choose at which stage to implement your method.
"""
def wrap_handler(middleware, *args, **kwargs):
if len(middleware) == 0:
return base_func(buffer, *args, **kwargs)
def chain(*args, **kwargs):
return wrap_handler(middleware[1:], *args, **kwargs)
func = middleware[0]
return func(func_name, chain, *args, **kwargs)
return wrap_handler(buffer.middleware, *args, **kwargs)
return handler
return wrap_func
@dataclass
class BufferedData:
data: Any
index: str
meta: dict
class Buffer:
"""
Buffer is an abstraction over storage devices, third-party services or data structures,
for example, a memory queue, sum-tree, Redis, or di-store.
"""
def __init__(self) -> None:
self.middleware = []
@abstractmethod
def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
"""
Overview:
Push data and its meta information into the buffer.
Arguments:
- data (:obj:`Any`): The data which will be pushed into buffer.
- meta (:obj:`dict`): Meta information, e.g. priority, count, staleness.
Returns:
- buffered_data (:obj:`BufferedData`): The pushed data.
"""
raise NotImplementedError
@abstractmethod
def sample(
self,
size: Optional[int] = None,
indices: Optional[List[str]] = None,
replace: bool = False,
sample_range: Optional[slice] = None,
ignore_insufficient: bool = False,
groupby: str = None,
rolling_window: int = None
) -> Union[List[BufferedData], List[List[BufferedData]]]:
"""
Overview:
Sample data with length ``size``.
Arguments:
- size (:obj:`Optional[int]`): The number of the data that will be sampled.
- indices (:obj:`Optional[List[str]]`): Sample with multiple indices.
- replace (:obj:`bool`): If replace is true, the sampled data may contain duplicated records.
- sample_range (:obj:`slice`): Sample range slice.
- ignore_insufficient (:obj:`bool`): If ignore_insufficient is true, sampling more data than the buffer
holds (without replacement) will only log a warning instead of raising an exception.
- groupby (:obj:`str`): Groupby key in meta.
- rolling_window (:obj:`int`): Return batches of window size.
Returns:
- sample_data (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`):
A list of data with length ``size``, may be nested if groupby or rolling_window is set.
"""
raise NotImplementedError
@abstractmethod
def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool:
"""
Overview:
Update data and meta by index
Arguments:
- index (:obj:`str`): Index of data.
- data (:obj:`any`): Pure data.
- meta (:obj:`dict`): Meta information.
Returns:
- success (:obj:`bool`): Success or not; if no data with this index exists in the buffer, return false.
"""
raise NotImplementedError
@abstractmethod
def batch_update(
self,
indices: List[str],
datas: Optional[List[Optional[Any]]] = None,
metas: Optional[List[Optional[dict]]] = None
) -> None:
"""
Overview:
Batch update data and meta by indices, maybe useful in some data architectures.
Arguments:
- indices (:obj:`List[str]`): Index of data.
- datas (:obj:`Optional[List[Optional[Any]]]`): Pure data.
- metas (:obj:`Optional[List[Optional[dict]]]`): Meta information.
"""
raise NotImplementedError
@abstractmethod
def delete(self, index: str):
"""
Overview:
Delete one data sample by index
Arguments:
- index (:obj:`str`): Index
"""
raise NotImplementedError
@abstractmethod
def count(self) -> int:
raise NotImplementedError
@abstractmethod
def clear(self) -> None:
raise NotImplementedError
@abstractmethod
def get(self, idx: int) -> BufferedData:
"""
Overview:
Get item by subscript index
Arguments:
- idx (:obj:`int`): Subscript index
Returns:
- buffered_data (:obj:`BufferedData`): Item from buffer
"""
raise NotImplementedError
def use(self, func: Callable) -> "Buffer":
r"""
Overview:
Use algorithm middleware to modify the behavior of the buffer.
Every middleware should be a callable which receives three parts of arguments:
1. The name of the method called by the user, e.g. `push`, `sample`, `update`,
`delete` or `clear`, so you can decide which actions to intercept.
2. A `chain` function, which executes the next middleware (or the buffer's own
method once the chain is exhausted).
3. The remaining arguments passed by the user to the original method, forwarded
in *args and **kwargs.
A middleware normally calls `chain(*args, **kwargs)` and returns its result; it may
modify the arguments before the call, post-process the return value, or skip the
call entirely to short-circuit the operation.
Arguments:
- func (:obj:`Callable`): The middleware handler
Returns:
- buffer (:obj:`Buffer`): The instance self
"""
self.middleware.append(func)
return self
def view(self) -> "Buffer":
r"""
Overview:
A view is a new instance of the buffer: its other properties are independent copies,
while the storage is shared among all views of the same buffer.
Returns:
- buffer (:obj:`Buffer`): The new buffer view
"""
return copy.copy(self)
def __copy__(self) -> "Buffer":
raise NotImplementedError
def __len__(self) -> int:
return self.count()
def __getitem__(self, idx: int) -> BufferedData:
return self.get(idx)
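# Illustrative sketch (not part of the original commit): a minimal middleware that
# follows the (action, chain, *args, **kwargs) protocol described in ``Buffer.use``.
# It counts how many records pass through ``push`` and delegates everything else.
def count_push_example(counter: dict) -> Callable:
    def _count_push(action: str, chain: Callable, *args, **kwargs):
        if action == "push":
            counter["push"] = counter.get("push", 0) + 1
        # Always forward to the next middleware or the buffer's own method
        return chain(*args, **kwargs)
    return _count_push

# Hypothetical usage with a concrete buffer such as DequeBuffer:
#     counter = {}
#     buffer = DequeBuffer(size=10).use(count_push_example(counter))
#     buffer.push("data")   # counter["push"] == 1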
from typing import Any, Iterable, List, Optional, Tuple, Union
from collections import defaultdict, deque, OrderedDict
from ding.worker.buffer import Buffer, apply_middleware, BufferedData
from ding.worker.buffer.utils import fastcopy
import itertools
import random
import uuid
import logging
class BufferIndex():
"""
Overview:
Map each record's index string to its offset in the storage deque, keeping the mapping
valid as the oldest records are evicted.
"""
def __init__(self, maxlen: int, *args, **kwargs):
self.maxlen = maxlen
self.__map = OrderedDict(*args, **kwargs)
self._last_key = next(reversed(self.__map)) if len(self) > 0 else None
self._cumlen = len(self.__map)
def get(self, key: str) -> int:
value = self.__map[key]
value = value % self._cumlen + min(0, (self.maxlen - self._cumlen))
return value
def __len__(self) -> int:
return len(self.__map)
def has(self, key: str) -> bool:
return key in self.__map
def append(self, key: str):
self.__map[key] = self.__map[self._last_key] + 1 if self._last_key else 0
self._last_key = key
self._cumlen += 1
if len(self) > self.maxlen:
self.__map.popitem(last=False)
def clear(self):
self.__map = OrderedDict()
self._last_key = None
self._cumlen = 0
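# Illustrative sketch (not part of the original commit): BufferIndex keeps track of
# each record index's current position in the storage deque, even after the oldest
# records are evicted.
#
#     idx = BufferIndex(maxlen=2)
#     idx.append("a"); idx.append("b"); idx.append("c")   # "a" falls out of the window
#     assert idx.get("b") == 0 and idx.get("c") == 1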
class DequeBuffer(Buffer):
def __init__(self, size: int) -> None:
super().__init__()
self.storage = deque(maxlen=size)
# Meta index is a dict whose values are deques aligned with the storage
self.indices = BufferIndex(maxlen=size)
self.meta_index = {}
@apply_middleware("push")
def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
return self._push(data, meta)
@apply_middleware("sample")
def sample(
self,
size: Optional[int] = None,
indices: Optional[List[str]] = None,
replace: bool = False,
sample_range: Optional[slice] = None,
ignore_insufficient: bool = False,
groupby: str = None,
rolling_window: int = None
) -> Union[List[BufferedData], List[List[BufferedData]]]:
storage = self.storage
if sample_range:
storage = list(itertools.islice(self.storage, sample_range.start, sample_range.stop, sample_range.step))
# Size and indices
assert size or indices, "One of size and indices must not be empty."
if (size and indices) and (size != len(indices)):
raise AssertionError("Size and indices length must be equal.")
if not size:
size = len(indices)
# Indices and groupby
assert not (indices and groupby), "Cannot use groupby and indices at the same time."
# Groupby and rolling_window
assert not (groupby and rolling_window), "Cannot use groupby and rolling_window at the same time."
assert not (indices and rolling_window), "Cannot use indices and rolling_window at the same time."
value_error = None
sampled_data = []
if indices:
indices_set = set(indices)
hashed_data = filter(lambda item: item.index in indices_set, storage)
hashed_data = map(lambda item: (item.index, item), hashed_data)
hashed_data = dict(hashed_data)
# Re-sample and return in indices order
sampled_data = [hashed_data[index] for index in indices]
elif groupby:
sampled_data = self._sample_by_group(size=size, groupby=groupby, replace=replace, storage=storage)
elif rolling_window:
sampled_data = self._sample_by_rolling_window(
size=size, replace=replace, rolling_window=rolling_window, storage=storage
)
else:
if replace:
sampled_data = random.choices(storage, k=size)
else:
try:
sampled_data = random.sample(storage, k=size)
except ValueError as e:
value_error = e
if value_error or len(sampled_data) != size:
if ignore_insufficient:
logging.warning(
"Sample operation is ignored due to data insufficient, current buffer is {} while sample is {}".
format(self.count(), size)
)
else:
raise ValueError("There are less than {} records/groups in buffer({})".format(size, self.count()))
sampled_data = self._independence(sampled_data)
return sampled_data
@apply_middleware("update")
def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool:
if not self.indices.has(index):
return False
i = self.indices.get(index)
item = self.storage[i]
if data is not None:
item.data = data
if meta is not None:
item.meta = meta
for key in self.meta_index:
self.meta_index[key][i] = meta[key] if key in meta else None
return True
@apply_middleware("delete")
def delete(self, indices: Union[str, Iterable[str]]) -> None:
if isinstance(indices, str):
indices = [indices]
del_idx = []
for index in indices:
if self.indices.has(index):
del_idx.append(self.indices.get(index))
if len(del_idx) == 0:
return
del_idx = sorted(del_idx, reverse=True)
for idx in del_idx:
del self.storage[idx]
remain_indices = [item.index for item in self.storage]
# Rebuild the index map so each remaining record points to its new position
key_value_pairs = zip(remain_indices, range(len(remain_indices)))
self.indices = BufferIndex(self.storage.maxlen, key_value_pairs)
def count(self) -> int:
return len(self.storage)
def get(self, idx: int) -> BufferedData:
return self.storage[idx]
@apply_middleware("clear")
def clear(self) -> None:
self.storage.clear()
self.indices.clear()
self.meta_index = {}
def import_data(self, data_with_meta: List[Tuple[Any, dict]]) -> None:
for data, meta in data_with_meta:
self._push(data, meta)
def export_data(self) -> List[BufferedData]:
return list(self.storage)
def _push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
index = uuid.uuid1().hex
if meta is None:
meta = {}
buffered = BufferedData(data=data, index=index, meta=meta)
self.storage.append(buffered)
self.indices.append(index)
# Add meta index
for key in self.meta_index:
self.meta_index[key].append(meta[key] if key in meta else None)
return buffered
def _independence(
self, buffered_samples: Union[List[BufferedData], List[List[BufferedData]]]
) -> Union[List[BufferedData], List[List[BufferedData]]]:
"""
Overview:
Make sure each sampled record object is distinct from the others (duplicated occurrences
are copied). Note this differs from clone_object: records are not insulated from the
buffer, so modifying a sampled record may still change the data stored in the buffer.
Arguments:
- buffered_samples (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`) Sampled data,
can be nested if groupby or rolling_window has been set.
"""
if len(buffered_samples) == 0:
return buffered_samples
occurred = defaultdict(int)
for i, buffered in enumerate(buffered_samples):
if isinstance(buffered, list):
sampled_list = buffered
# Loop over nested samples
for j, buffered in enumerate(sampled_list):
occurred[buffered.index] += 1
if occurred[buffered.index] > 1:
sampled_list[j] = fastcopy.copy(buffered)
elif isinstance(buffered, BufferedData):
occurred[buffered.index] += 1
if occurred[buffered.index] > 1:
buffered_samples[i] = fastcopy.copy(buffered)
else:
raise Exception("Get unexpected buffered type {}".format(type(buffered)))
return buffered_samples
def _sample_by_group(self,
size: int,
groupby: str,
replace: bool = False,
storage: deque = None) -> List[List[BufferedData]]:
"""
Overview:
Sample by `group` instead of by individual records. The result is a collection of
`size` lists, and the lists may differ in length.
"""
if storage is None:
storage = self.storage
if groupby not in self.meta_index:
self._create_index(groupby)
meta_indices = list(set(self.meta_index[groupby]))
sampled_groups = []
if replace:
sampled_groups = random.choices(meta_indices, k=size)
else:
try:
sampled_groups = random.sample(meta_indices, k=size)
except ValueError:
raise ValueError("There are less than {} groups in buffer({} groups)".format(size, len(meta_indices)))
sampled_data = defaultdict(list)
for buffered in storage:
meta_value = buffered.meta[groupby] if groupby in buffered.meta else None
if meta_value in sampled_groups:
sampled_data[buffered.meta[groupby]].append(buffered)
return list(sampled_data.values())
def _sample_by_rolling_window(
self,
size: Optional[int] = None,
replace: bool = False,
rolling_window: int = None,
storage: deque = None
) -> List[List[BufferedData]]:
if storage is None:
storage = self.storage
if replace:
sampled_indices = random.choices(range(len(storage)), k=size)
else:
try:
sampled_indices = random.sample(range(len(storage)), k=size)
except ValueError:
# Let the caller handle insufficient data
sampled_indices = []
sampled_data = []
for idx in sampled_indices:
slice_ = list(itertools.islice(storage, idx, idx + rolling_window))
sampled_data.append(slice_)
return sampled_data
def _create_index(self, meta_key: str):
self.meta_index[meta_key] = deque(maxlen=self.storage.maxlen)
for data in self.storage:
self.meta_index[meta_key].append(data.meta[meta_key] if meta_key in data.meta else None)
def __iter__(self) -> deque:
return iter(self.storage)
def __copy__(self) -> "DequeBuffer":
buffer = type(self)(size=self.storage.maxlen)
buffer.storage = self.storage
return buffer
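# Illustrative usage sketch (not part of the original commit), exercising the main
# sample modes described in ``Buffer.sample``.
if __name__ == "__main__":
    buffer = DequeBuffer(size=10)
    for i in range(10):
        buffer.push(i, meta={"episode": i % 2})
    plain = buffer.sample(4)                      # 4 distinct records
    dup = buffer.sample(20, replace=True)         # duplicates allowed
    groups = buffer.sample(2, groupby="episode")  # one list of records per group
    windows = buffer.sample(3, rolling_window=2)  # 3 slices of up to 2 records each
    print(len(plain), len(dup), len(groups), len(windows))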
from typing import Optional
import copy
from easydict import EasyDict
import numpy as np
from ding.worker.buffer import DequeBuffer
from ding.worker.buffer.middleware import use_time_check, PriorityExperienceReplay
from ding.utils import BUFFER_REGISTRY
@BUFFER_REGISTRY.register('deque')
class DequeBufferWrapper(object):
@classmethod
def default_config(cls: type) -> EasyDict:
cfg = EasyDict(copy.deepcopy(cls.config))
cfg.cfg_type = cls.__name__ + 'Dict'
return cfg
config = dict(
replay_buffer_size=10000,
max_use=float("inf"),
train_iter_per_log=100,
priority=False,
priority_IS_weight=False,
priority_power_factor=0.6,
IS_weight_power_factor=0.4,
IS_weight_anneal_train_iter=int(1e5),
priority_max_limit=1000,
)
def __init__(
self,
cfg: EasyDict,
tb_logger: Optional[object] = None,
exp_name: str = 'default_experiement',
instance_name: str = 'buffer'
) -> None:
self.cfg = cfg
self.priority_max_limit = cfg.priority_max_limit
self.name = '{}_iter'.format(instance_name)
self.tb_logger = tb_logger
self.buffer = DequeBuffer(size=cfg.replay_buffer_size)
self.last_log_train_iter = -1
# use_count middleware
if self.cfg.max_use != float("inf"):
self.buffer.use(use_time_check(self.buffer, max_use=self.cfg.max_use))
# priority middleware
if self.cfg.priority:
self.buffer.use(
PriorityExperienceReplay(
self.buffer,
self.cfg.replay_buffer_size,
IS_weight=self.cfg.priority_IS_weight,
priority_power_factor=self.cfg.priority_power_factor,
IS_weight_power_factor=self.cfg.IS_weight_power_factor,
IS_weight_anneal_train_iter=self.cfg.IS_weight_anneal_train_iter
)
)
self.last_sample_index = None
self.last_sample_meta = None
def sample(self, size: int, train_iter: int):
output = self.buffer.sample(size=size, ignore_insufficient=True)
if len(output) > 0:
if self.last_log_train_iter == -1 or train_iter - self.last_log_train_iter >= self.cfg.train_iter_per_log:
meta = [o.meta for o in output]
if self.cfg.max_use != float("inf"):
use_count_avg = np.mean([m['use_count'] for m in meta])
self.tb_logger.add_scalar('{}/use_count_avg'.format(self.name), use_count_avg, train_iter)
if self.cfg.priority:
self.last_sample_index = [o.index for o in output]
self.last_sample_meta = meta
priority_list = [m['priority'] for m in meta]
priority_avg = np.mean(priority_list)
priority_max = np.max(priority_list)
self.tb_logger.add_scalar('{}/priority_avg'.format(self.name), priority_avg, train_iter)
self.tb_logger.add_scalar('{}/priority_max'.format(self.name), priority_max, train_iter)
self.tb_logger.add_scalar('{}/buffer_data_count'.format(self.name), self.buffer.count(), train_iter)
data = [o.data for o in output]
if self.cfg.priority_IS_weight:
IS = [o.meta['priority_IS'] for o in output]
for i in range(len(data)):
data[i]['IS'] = IS[i]
return data
else:
return None
def push(self, data, cur_collector_envstep: int = -1) -> None:
for d in data:
meta = {}
if self.cfg.priority and 'priority' in d:
init_priority = d.pop('priority')
meta['priority'] = init_priority
self.buffer.push(d, meta=meta)
def update(self, meta: dict) -> None:
if not self.cfg.priority:
return
if self.last_sample_index is None:
return
new_meta = self.last_sample_meta
for m, p in zip(new_meta, meta['priority']):
m['priority'] = min(self.priority_max_limit, p)
for idx, m in zip(self.last_sample_index, new_meta):
self.buffer.update(idx, data=None, meta=m)
self.last_sample_index = None
self.last_sample_meta = None
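# Illustrative sketch (not part of the original commit): constructing the wrapper
# directly from its default config; the SummaryWriter path and the sample data here
# are assumptions for the example only.
#
#     from tensorboardX import SummaryWriter
#     cfg = DequeBufferWrapper.default_config()
#     buffer = DequeBufferWrapper(cfg, tb_logger=SummaryWriter("./log"))
#     buffer.push([{"obs": 1}, {"obs": 2}])
#     batch = buffer.sample(2, train_iter=0)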
from .clone_object import clone_object
from .use_time_check import use_time_check
from .staleness_check import staleness_check
from .priority import PriorityExperienceReplay
from .padding import padding
from .group_sample import group_sample
from typing import Callable, Any, List, Union
from ding.worker.buffer import BufferedData
from ding.worker.buffer.utils import fastcopy
def clone_object():
"""
This middleware freezes the objects saved in the memory buffer by working on copies.
Use it when you need to keep an object in the buffer unchanged while modifying
the object after sampling it (usually in multiple threads).
"""
def push(chain: Callable, data: Any, *args, **kwargs) -> BufferedData:
data = fastcopy.copy(data)
return chain(data, *args, **kwargs)
def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
data = chain(*args, **kwargs)
return fastcopy.copy(data)
def _clone_object(action: str, chain: Callable, *args, **kwargs):
if action == "push":
return push(chain, *args, **kwargs)
elif action == "sample":
return sample(chain, *args, **kwargs)
return chain(*args, **kwargs)
return _clone_object
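# Illustrative sketch (not part of the original commit): with clone_object enabled,
# records returned by ``sample`` are copies, so mutating them does not touch the
# data kept in the buffer (DequeBuffer is assumed here for the example).
#
#     buffer = DequeBuffer(size=10).use(clone_object())
#     buffer.push({"key": "v1"})
#     sampled = buffer.sample(1)[0]
#     sampled.data["key"] = "v2"
#     assert buffer.sample(1)[0].data["key"] == "v1"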
import random
from typing import Callable, List
from ding.worker.buffer.buffer import BufferedData
def group_sample(size_in_group: int, ordered_in_group: bool = True, max_use_in_group: bool = True) -> Callable:
"""
Overview:
The middleware is designed to process the data in each group after sampling from the buffer.
Arguments:
- size_in_group (:obj:`int`): Sample size in each group.
- ordered_in_group (:obj:`bool`): Whether to keep the original order of records, default is true.
- max_use_in_group (:obj:`bool`): Whether to use as much data in each group as possible, default is true.
"""
def sample(chain: Callable, *args, **kwargs) -> List[List[BufferedData]]:
if not kwargs.get("groupby"):
raise Exception("Group sample must be used when the `groupby` parameter is specified.")
sampled_data = chain(*args, **kwargs)
for i, grouped_data in enumerate(sampled_data):
if ordered_in_group:
if max_use_in_group:
end = max(0, len(grouped_data) - size_in_group) + 1
else:
end = len(grouped_data)
start_idx = random.choice(range(end))
sampled_data[i] = grouped_data[start_idx:start_idx + size_in_group]
else:
sampled_data[i] = random.sample(grouped_data, k=size_in_group)
return sampled_data
def _group_sample(action: str, chain: Callable, *args, **kwargs):
if action == "sample":
return sample(chain, *args, **kwargs)
return chain(*args, **kwargs)
return _group_sample
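# Illustrative sketch (not part of the original commit): group_sample trims every
# group of a ``groupby`` sample down to ``size_in_group`` records (DequeBuffer is
# assumed here for the example).
#
#     buffer = DequeBuffer(size=10).use(group_sample(size_in_group=2))
#     for i in range(6):
#         buffer.push(i, {"episode": i % 2})
#     groups = buffer.sample(2, groupby="episode")
#     assert all(len(g) == 2 for g in groups)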
import random
from typing import Callable, Union, List
from ding.worker.buffer import BufferedData
from ding.worker.buffer.utils import fastcopy
def padding(policy="random"):
"""
Overview:
Fill the nested buffer list to the same size as the largest list.
The default policy `random` will randomly select data from each group
and fill it into the current group list.
Arguments:
- policy (:obj:`str`): Padding policy, supports `random`, `none`.
"""
def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
sampled_data = chain(*args, **kwargs)
if len(sampled_data) == 0 or isinstance(sampled_data[0], BufferedData):
return sampled_data
max_len = len(max(sampled_data, key=len))
for i, grouped_data in enumerate(sampled_data):
group_len = len(grouped_data)
if group_len == max_len:
continue
for _ in range(max_len - group_len):
if policy == "random":
sampled_data[i].append(fastcopy.copy(random.choice(grouped_data)))
elif policy == "none":
sampled_data[i].append(BufferedData(data=None, index=None, meta=None))
return sampled_data
def _padding(action: str, chain: Callable, *args, **kwargs):
if action == "sample":
return sample(chain, *args, **kwargs)
return chain(*args, **kwargs)
return _padding
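# Illustrative sketch (not part of the original commit): padding fills the smaller
# groups of a ``groupby`` sample up to the size of the largest group (DequeBuffer is
# assumed here for the example).
#
#     buffer = DequeBuffer(size=10).use(padding())
#     for i in range(5):
#         buffer.push(i, {"episode": 0 if i < 3 else 1})   # groups of size 3 and 2
#     groups = buffer.sample(2, groupby="episode")
#     assert all(len(g) == 3 for g in groups)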
from typing import Callable, Any, List, Dict, Optional, Union
import copy
import numpy as np
from ding.utils import SumSegmentTree, MinSegmentTree
from ding.worker.buffer.buffer import BufferedData
class PriorityExperienceReplay:
def __init__(
self,
buffer: 'Buffer', # noqa
buffer_size: int,
IS_weight: bool = True,
priority_power_factor: float = 0.6,
IS_weight_power_factor: float = 0.4,
IS_weight_anneal_train_iter: int = int(1e5),
) -> None:
self.buffer = buffer
self.buffer_idx = {}
self.buffer_size = buffer_size
self.IS_weight = IS_weight
self.priority_power_factor = priority_power_factor
self.IS_weight_power_factor = IS_weight_power_factor
self.IS_weight_anneal_train_iter = IS_weight_anneal_train_iter
# Max priority so far; it is used to initialize a record's priority if "priority" is not passed in with the data.
self.max_priority = 1.0
# Capacity needs to be the power of 2.
capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
self.sum_tree = SumSegmentTree(capacity)
if self.IS_weight:
self.min_tree = MinSegmentTree(capacity)
self.delta_anneal = (1 - self.IS_weight_power_factor) / self.IS_weight_anneal_train_iter
self.pivot = 0
def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> BufferedData:
if meta is None:
meta = {'priority': self.max_priority}
else:
if 'priority' not in meta:
meta['priority'] = self.max_priority
meta['priority_idx'] = self.pivot
self._update_tree(meta['priority'], self.pivot)
buffered = chain(data, meta=meta, *args, **kwargs)
index = buffered.index
self.buffer_idx[self.pivot] = index
self.pivot = (self.pivot + 1) % self.buffer_size
return buffered
def sample(self, chain: Callable, size: int, *args,
**kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
# Divide [0, 1) into `size` equal intervals
intervals = np.array([i * 1.0 / size for i in range(size)])
# Uniformly sample within each interval
mass = intervals + np.random.uniform(size=(size, )) * 1. / size
# Rescale to [0, S), where S is the sum of all data priorities (the root value of the sum tree)
mass *= self.sum_tree.reduce()
indices = [self.sum_tree.find_prefixsum_idx(m) for m in mass]
indices = [self.buffer_idx[i] for i in indices]
# Sample with indices
data = chain(indices=indices, *args, **kwargs)
if self.IS_weight:
# Calculate max weight for normalizing IS
sum_tree_root = self.sum_tree.reduce()
p_min = self.min_tree.reduce() / sum_tree_root
buffer_count = self.buffer.count()
max_weight = (buffer_count * p_min) ** (-self.IS_weight_power_factor)
for i in range(len(data)):
meta = data[i].meta
priority_idx = meta['priority_idx']
p_sample = self.sum_tree[priority_idx] / sum_tree_root
weight = (buffer_count * p_sample) ** (-self.IS_weight_power_factor)
meta['priority_IS'] = weight / max_weight
self.IS_weight_power_factor = min(1.0, self.IS_weight_power_factor + self.delta_anneal)
return data
def update(self, chain: Callable, index: str, data: Any, meta: Any, *args, **kwargs) -> None:
update_flag = chain(index, data, meta, *args, **kwargs)
if update_flag: # when update succeed
assert meta is not None, "Please indicate dict-type meta in priority update"
new_priority, idx = meta['priority'], meta['priority_idx']
assert new_priority >= 0, "new_priority should be greater than or equal to 0, but found {}".format(new_priority)
new_priority += 1e-5 # Add epsilon to avoid priority == 0
self._update_tree(new_priority, idx)
self.max_priority = max(self.max_priority, new_priority)
def delete(self, chain: Callable, index: str, *args, **kwargs) -> None:
for item in self.buffer.storage:
meta = item.meta
priority_idx = meta['priority_idx']
self.sum_tree[priority_idx] = self.sum_tree.neutral_element
self.min_tree[priority_idx] = self.min_tree.neutral_element
self.buffer_idx.pop(priority_idx)
return chain(index, *args, **kwargs)
def clear(self, chain: Callable) -> None:
self.max_priority = 1.0
capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
self.sum_tree = SumSegmentTree(capacity)
if self.IS_weight:
self.min_tree = MinSegmentTree(capacity)
self.buffer_idx = {}
self.pivot = 0
chain()
def _update_tree(self, priority: float, idx: int) -> None:
weight = priority ** self.priority_power_factor
self.sum_tree[idx] = weight
if self.IS_weight:
self.min_tree[idx] = weight
def state_dict(self) -> Dict:
return {
'max_priority': self.max_priority,
'IS_weight_power_factor': self.IS_weight_power_factor,
'sumtree': self.sum_tree,
'mintree': self.min_tree,
'buffer_idx': self.buffer_idx,
}
def load_state_dict(self, _state_dict: Dict, deepcopy: bool = False) -> None:
for k, v in _state_dict.items():
if deepcopy:
setattr(self, '{}'.format(k), copy.deepcopy(v))
else:
setattr(self, '{}'.format(k), v)
def __call__(self, action: str, chain: Callable, *args, **kwargs) -> Any:
if action in ["push", "sample", "update", "delete", "clear"]:
return getattr(self, action)(chain, *args, **kwargs)
return chain(*args, **kwargs)
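# Illustrative sketch (not part of the original commit): attaching prioritized replay
# to a DequeBuffer (assumed here). Records pushed without an explicit priority get the
# current max priority, and priorities can be refreshed later through ``update``.
#
#     buffer = DequeBuffer(size=8)
#     buffer.use(PriorityExperienceReplay(buffer, buffer_size=8, IS_weight=True))
#     buffer.push({"obs": 1}, meta={"priority": 2.0})
#     item = buffer.sample(1)[0]          # meta now also holds priority_idx / priority_IS
#     item.meta["priority"] = 3.0
#     buffer.update(item.index, item.data, item.meta)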
from typing import Callable, Any, List
def staleness_check(buffer_: 'Buffer', max_staleness: int = float("inf")) -> Callable: # noqa
"""
Overview:
This middleware checks staleness before each sample operation.
staleness = train_iter_sample_data - train_iter_data_collected, i.e. how old/off-policy the data is.
If a record's staleness is greater (>) than max_staleness, it will be removed from the buffer as soon as possible.
"""
def push(next: Callable, data: Any, *args, **kwargs) -> Any:
assert 'meta' in kwargs and 'train_iter_data_collected' in kwargs[
'meta'], "staleness_check middleware must push data with meta={'train_iter_data_collected': <iter>}"
return next(data, *args, **kwargs)
def sample(next: Callable, train_iter_sample_data: int, *args, **kwargs) -> List[Any]:
delete_index = []
for i, item in enumerate(buffer_.storage):
index, meta = item.index, item.meta
staleness = train_iter_sample_data - meta['train_iter_data_collected']
meta['staleness'] = staleness
if staleness > max_staleness:
delete_index.append(index)
for index in delete_index:
buffer_.delete(index)
data = next(*args, **kwargs)
return data
def _staleness_check(action: str, next: Callable, *args, **kwargs) -> Any:
if action == "push":
return push(next, *args, **kwargs)
elif action == "sample":
return sample(next, *args, **kwargs)
return next(*args, **kwargs)
return _staleness_check
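# Illustrative sketch (not part of the original commit): every push must carry the
# train iteration at which the data was collected, and each sample passes the current
# train iteration so that overly stale records are dropped first (DequeBuffer is
# assumed here for the example).
#
#     buffer = DequeBuffer(size=10)
#     buffer.use(staleness_check(buffer, max_staleness=10))
#     buffer.push({"obs": 1}, meta={"train_iter_data_collected": 0})
#     data = buffer.sample(size=1, train_iter_sample_data=5)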
from collections import defaultdict
from typing import Callable, Any, List, Optional, Union
from ding.worker.buffer import BufferedData
def use_time_check(buffer_: 'Buffer', max_use: int = float("inf")) -> Callable: # noqa
"""
Overview:
This middleware tracks how many times each record in the buffer has been used. If a record's
use count is greater than or equal to max_use, it will be removed from the buffer as soon as possible.
"""
use_count = defaultdict(int)
def _need_delete(item: BufferedData) -> bool:
nonlocal use_count
idx = item.index
use_count[idx] += 1
item.meta['use_count'] = use_count[idx]
if use_count[idx] >= max_use:
return True
else:
return False
def _check_use_count(sampled_data: List[BufferedData]):
delete_indices = [item.index for item in filter(_need_delete, sampled_data)]
buffer_.delete(delete_indices)
for index in delete_indices:
del use_count[index]
def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
sampled_data = chain(*args, **kwargs)
if len(sampled_data) == 0:
return sampled_data
if isinstance(sampled_data[0], BufferedData):
_check_use_count(sampled_data)
else:
for grouped_data in sampled_data:
_check_use_count(grouped_data)
return sampled_data
def _use_time_check(action: str, chain: Callable, *args, **kwargs) -> Any:
if action == "sample":
return sample(chain, *args, **kwargs)
return chain(*args, **kwargs)
return _use_time_check
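# Illustrative sketch (not part of the original commit): each sampled record gets a
# ``use_count`` entry in its meta, and a record that reaches ``max_use`` is deleted
# right after the sample that exhausted it (DequeBuffer is assumed here).
#
#     buffer = DequeBuffer(size=10)
#     buffer.use(use_time_check(buffer, max_use=2))
#     buffer.push({"obs": 1})
#     buffer.sample(1)     # use_count == 1
#     buffer.sample(1)     # use_count == 2, the record is removed afterwards
#     assert buffer.count() == 0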
import pytest
import time
import random
from typing import Callable
from ding.worker.buffer import DequeBuffer
from ding.worker.buffer.buffer import BufferedData
from torch.utils.data import DataLoader
class RateLimit:
r"""
Add rate limit threshold to push function
"""
def __init__(self, max_rate: int = float("inf"), window_seconds: int = 30) -> None:
self.max_rate = max_rate
self.window_seconds = window_seconds
self.buffered = []
def __call__(self, action: str, chain: Callable, *args, **kwargs):
if action == "push":
return self.push(chain, *args, **kwargs)
return chain(*args, **kwargs)
def push(self, chain, data, *args, **kwargs) -> None:
current = time.time()
# Cut off stale records
self.buffered = [t for t in self.buffered if t > current - self.window_seconds]
if len(self.buffered) < self.max_rate:
self.buffered.append(current)
return chain(data, *args, **kwargs)
else:
return None
def add_10() -> Callable:
"""
Transform data on sampling
"""
def sample(chain: Callable, size: int, replace: bool = False, *args, **kwargs):
sampled_data = chain(size, replace, *args, **kwargs)
return [BufferedData(data=item.data + 10, index=item.index, meta=item.meta) for item in sampled_data]
def _subview(action: str, chain: Callable, *args, **kwargs):
if action == "sample":
return sample(chain, *args, **kwargs)
return chain(*args, **kwargs)
return _subview
@pytest.mark.unittest
def test_naive_push_sample():
# Push and sample
buffer = DequeBuffer(size=10)
for i in range(20):
buffer.push(i)
assert buffer.count() == 10
assert 0 not in [item.data for item in buffer.sample(10)]
# Clear
buffer.clear()
assert buffer.count() == 0
# Test replace sample
for i in range(5):
buffer.push(i)
assert buffer.count() == 5
assert len(buffer.sample(10, replace=True)) == 10
# Test slicing
buffer.clear()
for i in range(10):
buffer.push(i)
assert len(buffer.sample(5, sample_range=slice(5, 10))) == 5
assert 0 not in [item.data for item in buffer.sample(5, sample_range=slice(5, 10))]
@pytest.mark.unittest
def test_rate_limit_push_sample():
buffer = DequeBuffer(size=10).use(RateLimit(max_rate=5))
for i in range(10):
buffer.push(i)
assert buffer.count() == 5
assert 5 not in buffer.sample(5)
@pytest.mark.unittest
def test_buffer_view():
buf1 = DequeBuffer(size=10)
for i in range(1):
buf1.push(i)
assert buf1.count() == 1
buf2 = buf1.view().use(RateLimit(max_rate=5)).use(add_10())
for i in range(10):
buf2.push(i)
# With 1 record written by buf1 and 5 records written by buf2
assert len(buf1.middleware) == 0
assert buf1.count() == 6
# All sampled data should be at least 10 because of `add_10`
assert all(d.data >= 10 for d in buf2.sample(5))
# But data in storage is still less than 10
assert all(d.data < 10 for d in buf1.sample(5))
@pytest.mark.unittest
def test_sample_with_index():
buf = DequeBuffer(size=10)
for i in range(10):
buf.push({"data": i}, {"meta": i})
# Random sample and get indices
indices = [item.index for item in buf.sample(10)]
assert len(indices) == 10
random.shuffle(indices)
indices = indices[:5]
# Resample by indices
new_indices = [item.index for item in buf.sample(indices=indices)]
assert len(new_indices) == len(indices)
for index in new_indices:
assert index in indices
@pytest.mark.unittest
def test_update():
buf = DequeBuffer(size=10)
for i in range(1):
buf.push({"data": i}, {"meta": i})
# Update one data
[item] = buf.sample(1)
item.data["new_prop"] = "any"
meta = None
success = buf.update(item.index, item.data, item.meta)
assert success
# Resample
[item] = buf.sample(1)
assert "new_prop" in item.data
assert meta is None
# Update an object that does not exist in the buffer
success = buf.update("invalidindex", {}, None)
assert not success
# When exceed buffer size
for i in range(20):
buf.push({"data": i})
assert len(buf.indices) == 10
assert len(buf.storage) == 10
for i in range(10):
index = buf.storage[i].index
assert buf.indices.get(index) == i
@pytest.mark.unittest
def test_delete():
maxlen = 100
cumlen = 40
dellen = 20
buf = DequeBuffer(size=maxlen)
for i in range(cumlen):
buf.push(i)
# Delete data
del_indices = [item.index for item in buf.sample(dellen)]
buf.delete(del_indices)
# Reappend
for i in range(10):
buf.push(i)
remlen = min(cumlen, maxlen) - dellen + 10
assert len(buf.indices) == remlen
assert len(buf.storage) == remlen
for i in range(remlen):
index = buf.storage[i].index
assert buf.indices.get(index) == i
@pytest.mark.unittest
def test_ignore_insufficient():
buffer = DequeBuffer(size=10)
for i in range(2):
buffer.push(i)
with pytest.raises(ValueError):
buffer.sample(3, ignore_insufficient=False)
data = buffer.sample(3, ignore_insufficient=True)
assert len(data) == 0
@pytest.mark.unittest
def test_independence():
# By replace
buffer = DequeBuffer(size=1)
data = {"key": "origin"}
buffer.push(data)
sampled_data = buffer.sample(2, replace=True)
assert len(sampled_data) == 2
sampled_data[0].data["key"] = "new"
assert sampled_data[1].data["key"] == "origin"
# By indices
buffer = DequeBuffer(size=1)
data = {"key": "origin"}
buffered = buffer.push(data)
indices = [buffered.index, buffered.index]
sampled_data = buffer.sample(indices=indices)
assert len(sampled_data) == 2
sampled_data[0].data["key"] = "new"
assert sampled_data[1].data["key"] == "origin"
@pytest.mark.unittest
def test_groupby():
buffer = DequeBuffer(size=3)
buffer.push("a", {"group": 1})
buffer.push("b", {"group": 2})
buffer.push("c", {"group": 2})
sampled_data = buffer.sample(2, groupby="group")
assert len(sampled_data) == 2
group1 = sampled_data[0] if len(sampled_data[0]) == 1 else sampled_data[1]
group2 = sampled_data[0] if len(sampled_data[0]) == 2 else sampled_data[1]
# Group1 should contain a
assert "a" == group1[0].data
# Group2 should contain b and c
data = [buffered.data for buffered in group2] # ["b", "c"]
assert "b" in data
assert "c" in data
# Push new data and swap out "a"; the resulting records will all be in group 2
buffer.push("d", {"group": 2})
sampled_data = buffer.sample(1, groupby="group")
assert len(sampled_data) == 1
assert len(sampled_data[0]) == 3
data = [buffered.data for buffered in sampled_data[0]]
assert "d" in data
# Update meta, set first data's group to 1
first: BufferedData = buffer.storage[0]
buffer.update(first.index, first.data, {"group": 1})
sampled_data = buffer.sample(2, groupby="group")
assert len(sampled_data) == 2
# Delete last record, each group will only have one record
last: BufferedData = buffer.storage[-1]
buffer.delete(last.index)
sampled_data = buffer.sample(2, groupby="group")
assert len(sampled_data) == 2
@pytest.mark.unittest
def test_rolling_window():
buffer = DequeBuffer(size=10)
for i in range(10):
buffer.push(i)
sampled_data = buffer.sample(10, rolling_window=3)
assert len(sampled_data) == 10
# Test data independence
buffer = DequeBuffer(size=2)
for i in range(2):
buffer.push({"key": i})
sampled_data = buffer.sample(2, rolling_window=3)
assert len(sampled_data) == 2
group_long = sampled_data[0] if len(sampled_data[0]) == 2 else sampled_data[1]
group_short = sampled_data[0] if len(sampled_data[0]) == 1 else sampled_data[1]
# Modify the second value
group_long[1].data["key"] = 10
assert group_short[0].data["key"] == 1
@pytest.mark.unittest
def test_import_export():
buffer = DequeBuffer(size=10)
data_with_meta = [(i, {}) for i in range(10)]
buffer.import_data(data_with_meta)
assert buffer.count() == 10
sampled_data = buffer.export_data()
assert len(sampled_data) == 10
@pytest.mark.unittest
def test_dataset():
buffer = DequeBuffer(size=10)
for i in range(10):
buffer.push(i)
dataloader = DataLoader(buffer, batch_size=6, shuffle=True, collate_fn=lambda batch: batch)
for batch in dataloader:
assert len(batch) in [4, 6]
import pytest
import torch
from ding.worker.buffer import DequeBuffer
from ding.worker.buffer.middleware import clone_object, use_time_check, staleness_check
from ding.worker.buffer.middleware import PriorityExperienceReplay, group_sample
from ding.worker.buffer.middleware.padding import padding
@pytest.mark.unittest
def test_clone_object():
buffer = DequeBuffer(size=10).use(clone_object())
# Store a dict, a list, a tensor
arr = [{"key": "v1"}, ["a"], torch.Tensor([1, 2, 3])]
for o in arr:
buffer.push(o)
# Modify it
for item in buffer.sample(len(arr)):
item = item.data
if isinstance(item, dict):
item["key"] = "v2"
elif isinstance(item, list):
item.append("b")
elif isinstance(item, torch.Tensor):
item[0] = 3
else:
raise Exception("Unexpected type")
# Resample it, and check their values
for item in buffer.sample(len(arr)):
item = item.data
if isinstance(item, dict):
assert item["key"] == "v1"
elif isinstance(item, list):
assert len(item) == 1
elif isinstance(item, torch.Tensor):
assert item[0] == 1
else:
raise Exception("Unexpected type")
def get_data():
return {'obs': torch.randn(4), 'reward': torch.randn(1), 'info': 'xxx'}
@pytest.mark.unittest
def test_use_time_check():
N = 6
buffer = DequeBuffer(size=10)
buffer.use(use_time_check(buffer, max_use=2))
for _ in range(N):
buffer.push(get_data())
for _ in range(2):
data = buffer.sample(size=N, replace=False)
assert len(data) == N
with pytest.raises(ValueError):
buffer.sample(size=1, replace=False)
@pytest.mark.unittest
def test_staleness_check():
N = 6
buffer = DequeBuffer(size=10)
buffer.use(staleness_check(buffer, max_staleness=10))
with pytest.raises(AssertionError):
buffer.push(get_data())
for _ in range(N):
buffer.push(get_data(), meta={'train_iter_data_collected': 0})
data = buffer.sample(size=N, replace=False, train_iter_sample_data=9)
assert len(data) == N
data = buffer.sample(size=N, replace=False, train_iter_sample_data=10) # edge case
assert len(data) == N
for _ in range(2):
buffer.push(get_data(), meta={'train_iter_data_collected': 5})
assert buffer.count() == 8
with pytest.raises(ValueError):
data = buffer.sample(size=N, replace=False, train_iter_sample_data=11)
assert buffer.count() == 2
@pytest.mark.unittest
def test_priority():
N = 5
buffer = DequeBuffer(size=10)
buffer.use(PriorityExperienceReplay(buffer, buffer_size=10, IS_weight=True))
for _ in range(N):
buffer.push(get_data())
assert buffer.count() == N
for _ in range(N):
buffer.push(get_data(), meta={'priority': 2.0})
assert buffer.count() == N + N
data = buffer.sample(size=N + N, replace=False)
assert len(data) == N + N
for item in data:
meta = item.meta
assert set(meta.keys()).issuperset(set(['priority', 'priority_idx', 'priority_IS']))
meta['priority'] = 3.0
for item in data:
data, index, meta = item.data, item.index, item.meta
buffer.update(index, data, meta)
data = buffer.sample(size=1)
assert data[0].meta['priority'] == 3.0
buffer.delete(data[0].index)
assert buffer.count() == N + N - 1
buffer.clear()
assert buffer.count() == 0
@pytest.mark.unittest
def test_padding():
buffer = DequeBuffer(size=10)
buffer.use(padding())
for i in range(10):
buffer.push(i, {"group": i & 5}) # [3,3,2,2]
sampled_data = buffer.sample(4, groupby="group")
assert len(sampled_data) == 4
for grouped_data in sampled_data:
assert len(grouped_data) == 3
@pytest.mark.unittest
def test_group_sample():
buffer = DequeBuffer(size=10)
buffer.use(padding(policy="none")).use(group_sample(size_in_group=5, ordered_in_group=True, max_use_in_group=True))
for i in range(4):
buffer.push(i, {"episode": 0})
for i in range(6):
buffer.push(i, {"episode": 1})
sampled_data = buffer.sample(2, groupby="episode")
assert len(sampled_data) == 2
def check_group0(grouped_data):
# In group 0 there should be exactly one padded record whose data is None
n_none = 0
for item in grouped_data:
if item.data is None:
n_none += 1
assert n_none == 1
def check_group1(grouped_data):
# In group1 every record should have data and meta
for item in grouped_data:
assert item.data is not None
for grouped_data in sampled_data:
assert len(grouped_data) == 5
meta = grouped_data[0].meta
if meta and "episode" in meta and meta["episode"] == 1:
check_group1(grouped_data)
else:
check_group0(grouped_data)
from .fast_copy import FastCopy, fastcopy
import torch
import numpy as np
from typing import Any, List
from ding.worker.buffer.buffer import BufferedData
class FastCopy:
"""
The idea of this class comes from this article
https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.
We use recursive calls to copy each object that needs to be copied, which will be 5x faster
than copy.deepcopy.
"""
def __init__(self):
dispatch = {}
dispatch[list] = self._copy_list
dispatch[dict] = self._copy_dict
dispatch[torch.Tensor] = self._copy_tensor
dispatch[np.ndarray] = self._copy_ndarray
dispatch[BufferedData] = self._copy_buffereddata
self.dispatch = dispatch
def _copy_list(self, l: List) -> List:
ret = l.copy()
for idx, item in enumerate(ret):
cp = self.dispatch.get(type(item))
if cp is not None:
ret[idx] = cp(item)
return ret
def _copy_dict(self, d: dict) -> dict:
ret = d.copy()
for key, value in ret.items():
cp = self.dispatch.get(type(value))
if cp is not None:
ret[key] = cp(value)
return ret
def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor:
return t.clone()
def _copy_ndarray(self, a: np.ndarray) -> np.ndarray:
return np.copy(a)
def _copy_buffereddata(self, d: BufferedData) -> BufferedData:
return BufferedData(data=self.copy(d.data), index=d.index, meta=self.copy(d.meta))
def copy(self, sth: Any) -> Any:
cp = self.dispatch.get(type(sth))
if cp is None:
return sth
else:
return cp(sth)
fastcopy = FastCopy()
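# Illustrative sketch (not part of the original commit): nested containers, tensors
# and arrays are copied recursively, while unregistered types are returned as-is.
if __name__ == "__main__":
    original = {"obs": torch.zeros(3), "info": ["a", np.arange(2)]}
    duplicate = fastcopy.copy(original)
    duplicate["obs"][0] = 1.0
    assert original["obs"][0] == 0.0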
@@ -51,6 +51,7 @@ pong_dqn_create_config = dict(
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqn'),
# replay_buffer=dict(type='deque'),
)
pong_dqn_create_config = EasyDict(pong_dqn_create_config)
create_config = pong_dqn_create_config
from easydict import EasyDict
from ding.entry import serial_pipeline
nstep = 3
lunarlander_dqn_default_config = dict(
exp_name='lunarlander_dqn_priority',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, ),
# Env number respectively for collector and evaluator.
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=200,
),
policy=dict(
# Whether to use cuda for network.
cuda=False,
priority=True,
priority_IS_weight=False,
model=dict(
obs_shape=8,
action_shape=4,
encoder_hidden_size_list=[512, 64],
# Whether to use dueling head.
dueling=True,
),
# Reward's future discount factor, aka. gamma.
discount_factor=0.99,
# How many steps in td error.
nstep=nstep,
# learn_mode config
learn=dict(
update_per_collect=10,
batch_size=64,
learning_rate=0.001,
# Frequency of target network update.
target_update_freq=100,
),
# collect_mode config
collect=dict(
# You can use either "n_sample" or "n_episode" in collector.collect.
# Get "n_sample" samples per collect.
n_sample=64,
# Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
),
# command_mode config
other=dict(
# Epsilon greedy with decay.
eps=dict(
# Decay type. Support ['exp', 'linear'].
type='exp',
start=0.95,
end=0.1,
decay=50000,
),
replay_buffer=dict(replay_buffer_size=100000, priority=True, priority_IS_weight=False)
),
),
)
lunarlander_dqn_default_config = EasyDict(lunarlander_dqn_default_config)
main_config = lunarlander_dqn_default_config
lunarlander_dqn_create_config = dict(
env=dict(
type='lunarlander',
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqn'),
replay_buffer=dict(type='deque'),
)
lunarlander_dqn_create_config = EasyDict(lunarlander_dqn_create_config)
create_config = lunarlander_dqn_create_config
if __name__ == "__main__":
serial_pipeline([main_config, create_config], seed=0)
import os
import gym
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, DequeBufferWrapper
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import DQNPolicy
from ding.model import DQN
from ding.utils import set_pkg_seed
from ding.rl_utils import get_epsilon_greedy_fn
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config
# Wrap the gym env into a DI-engine-style env class
def wrapped_cartpole_env():
return DingEnvWrapper(gym.make('CartPole-v0'))
def main(cfg, seed=0):
cfg = compile_config(
cfg,
BaseEnvManager,
DQNPolicy,
BaseLearner,
SampleSerialCollector,
InteractionSerialEvaluator,
DequeBufferWrapper,
save_cfg=True
)
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
# Set random seed for all package and instance
collector_env.seed(seed)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
# Set up RL Policy
model = DQN(**cfg.policy.model)
policy = DQNPolicy(cfg.policy, model=model)
# Set up collection, training and evaluation utilities
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = DequeBufferWrapper(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
# Set up other modules, e.g. epsilon greedy
eps_cfg = cfg.policy.other.eps
epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
# Training & Evaluation loop
while True:
# Evaluating at the beginning and with specific frequency
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Update other modules
eps = epsilon_greedy(collector.envstep)
# Sampling data from environments
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# Training
for i in range(cfg.policy.learn.update_per_collect):
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is None:
break
learner.train(train_data, collector.envstep)
if __name__ == "__main__":
main(cartpole_dqn_config)