未验证 提交 00234383 编写于 作者: L LuciusMos 提交者: GitHub

feature(zlx): add tb in naive buffer; modify tb in advanced buffer (#39)

* feature(zlx): Add tb in naive buffer; modify tb in advanced buffer

* feature(zlx): naive_buffer tb, fix bug in valid_count update
上级 cd401b92
......@@ -97,7 +97,8 @@ class AdvancedReplayBuffer(IBuffer):
Arguments:
- cfg (:obj:`dict`): Config dict.
- tb_logger (:obj:`Optional['SummaryWriter']`): Outer tb logger. Usually passed in when running in serial mode.
- name (:obj:`Optional[str]`): Buffer name, used to generate unique data id and logger name.
- exp_name (:obj:`Optional[str]`): Name of this experiment.
- instance_name (:obj:`Optional[str]`): Name of this instance.
"""
self._exp_name = exp_name
self._instance_name = instance_name
......@@ -615,10 +616,6 @@ class AdvancedReplayBuffer(IBuffer):
self._periodic_thruput_monitor.push_data_count += add_count
if self._use_thruput_controller:
self._thruput_controller.history_push_count += add_count
self._tb_logger.add_scalar(
'buffer_{}_sec/'.format(self._instance_name) + 'push', add_count,
time.time() - self._start_time
)
self._cur_collector_envstep = cur_collector_envstep
def _monitor_update_of_sample(self, sample_data: list, cur_learner_iter: int) -> None:
......@@ -670,10 +667,6 @@ class AdvancedReplayBuffer(IBuffer):
self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, iter_metric)
if step_metric is not None:
self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, step_metric)
self._tb_logger.add_scalar(
'{}_sec/'.format(self._instance_name) + 'sample', len(sample_data),
time.time() - self._start_time
)
self._sampled_data_attr_print_count += 1
def _calculate_staleness(self, pos_index: int, cur_learner_iter: int) -> Optional[int]:
......
import copy
from typing import Union, Any, Optional, List
import numpy as np
from easydict import EasyDict
from ding.worker.replay_buffer import IBuffer
from ding.utils import LockContext, LockContextType, BUFFER_REGISTRY
from .utils import UsedDataRemover
from ding.utils import LockContext, LockContextType, BUFFER_REGISTRY, build_logger
from .utils import UsedDataRemover, PeriodicThruputMonitor
@BUFFER_REGISTRY.register('naive')
......@@ -32,7 +33,7 @@ class NaiveReplayBuffer(IBuffer):
def __init__(
self,
cfg: 'EasyDict', # noqa
name: str = 'default',
tb_logger: Optional['SummaryWriter'] = None, # noqa
exp_name: Optional[str] = 'default_experiment',
instance_name: Optional[str] = 'buffer',
) -> None:
......@@ -41,7 +42,9 @@ class NaiveReplayBuffer(IBuffer):
Initialize the buffer
Arguments:
- cfg (:obj:`dict`): Config dict.
- name (:obj:`Optional[str]`): Buffer name, used to generate unique data id and logger name.
- tb_logger (:obj:`Optional['SummaryWriter']`): Outer tb logger. Usually passed in when running in serial mode.
- exp_name (:obj:`Optional[str]`): Name of this experiment.
- instance_name (:obj:`Optional[str]`): Name of this instance.
"""
self._exp_name = exp_name
self._instance_name = instance_name
......@@ -62,6 +65,20 @@ class NaiveReplayBuffer(IBuffer):
self._enable_track_used_data = self._cfg.enable_track_used_data
if self._enable_track_used_data:
self._used_data_remover = UsedDataRemover()
if tb_logger is not None:
self._logger, _ = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
)
self._tb_logger = tb_logger
else:
self._logger, self._tb_logger = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name),
self._instance_name,
)
# Periodic thruput monitor. The monitor range is set to 3 seconds here; you can adjust it as needed.
self._periodic_thruput_monitor = PeriodicThruputMonitor(
self._instance_name, EasyDict(seconds=3), self._logger, self._tb_logger
)
def start(self) -> None:
"""
......@@ -79,6 +96,8 @@ class NaiveReplayBuffer(IBuffer):
self.clear()
if self._enable_track_used_data:
self._used_data_remover.close()
self._tb_logger.flush()
self._tb_logger.close()
def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None:
r"""
......@@ -92,8 +111,10 @@ class NaiveReplayBuffer(IBuffer):
"""
if isinstance(data, list):
self._extend(data, cur_collector_envstep)
self._periodic_thruput_monitor.push_data_count += len(data)
else:
self._append(data, cur_collector_envstep)
self._periodic_thruput_monitor.push_data_count += 1
def sample(self, size: int, cur_learner_iter: int, sample_range: slice = None) -> Optional[list]:
"""
......@@ -115,8 +136,9 @@ class NaiveReplayBuffer(IBuffer):
return None
with self._lock:
indices = self._get_indices(size, sample_range)
result = self._sample_with_indices(indices, cur_learner_iter)
return result
sample_data = self._sample_with_indices(indices, cur_learner_iter)
self._periodic_thruput_monitor.sample_data_count += len(sample_data)
return sample_data
def _append(self, ori_data: Any, cur_collector_envstep: int = -1) -> None:
r"""
......@@ -134,6 +156,7 @@ class NaiveReplayBuffer(IBuffer):
self._push_count += 1
if self._data[self._tail] is None:
self._valid_count += 1
self._periodic_thruput_monitor.valid_count = self._valid_count
elif self._enable_track_used_data:
self._used_data_remover.add_used_data(self._data[self._tail])
self._data[self._tail] = data
......@@ -160,6 +183,7 @@ class NaiveReplayBuffer(IBuffer):
if self._tail + length <= self._replay_buffer_size:
if self._valid_count != self._replay_buffer_size:
self._valid_count += length
self._periodic_thruput_monitor.valid_count = self._valid_count
elif self._enable_track_used_data:
for i in range(length):
self._used_data_remover.add_used_data(self._data[self._tail + i])
......@@ -174,6 +198,7 @@ class NaiveReplayBuffer(IBuffer):
L = min(space, residual_num)
if self._valid_count != self._replay_buffer_size:
self._valid_count += L
self._periodic_thruput_monitor.valid_count = self._valid_count
elif self._enable_track_used_data:
for i in range(L):
self._used_data_remover.add_used_data(self._data[new_tail + i])
......@@ -226,6 +251,7 @@ class NaiveReplayBuffer(IBuffer):
self._used_data_remover.add_used_data(self._data[i])
self._data[i] = None
self._valid_count = 0
self._periodic_thruput_monitor.valid_count = self._valid_count
self._push_count = 0
self._tail = 0
......
......@@ -125,7 +125,7 @@ class PeriodicThruputMonitor:
"""
Overview:
PeriodicThruputMonitor is a tool that records and prints logs (text & tensorboard) of how many data items are
pushed/sampled/removed/valid in a period of time.
pushed/sampled/removed/valid in a period of time. For tensorboard, you can view it under '{$NAME}_sec'.
Interface:
close
Property:
......@@ -165,7 +165,7 @@ class PeriodicThruputMonitor:
}
self._logger.info(self._logger.get_tabulate_vars_hor(count_dict))
for k, v in count_dict.items():
self._tb_logger.add_scalar('buffer_{}_sec/'.format(self.name) + k, v, self._thruput_print_times)
self._tb_logger.add_scalar('{}_sec/'.format(self.name) + k, v, self._thruput_print_times)
self._history_push_count = 0
self._history_sample_count = 0
self._remove_data_count = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册