Commit fed80b44 authored by niuyazhe

polish(nyz): polish collector benchmark test(enable docker, smac docker)

Parent 0414eda5
......@@ -14,9 +14,19 @@ from ding.worker.collector.tests.speed_test.utils import random_change
 class FakePolicy(Policy):
     def default_config(cls: type) -> EasyDict:
         return EasyDict({})
+    config = dict(
+        cuda=False,
+        on_policy=False,
+        forward_time=0.002,
+        learn=dict(),
+        collect=dict(
+            n_sample=80,
+            unroll_len=1,
+            collector=dict(collect_print_freq=1000000),
+        ),
+        eval=dict(),
+        other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
+    )
     def __init__(
             self,
......@@ -25,7 +35,7 @@ class FakePolicy(Policy):
             enable_field: Optional[List[str]] = None
     ) -> None:
         self._cfg = cfg
-        self._use_cuda = cfg.use_cuda and torch.cuda.is_available()
+        self._cuda = cfg.cuda and torch.cuda.is_available()
         self._init_collect()
         self._forward_time = cfg.forward_time
         self._on_policy = cfg.on_policy
......
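The polish above moves FakePolicy's defaults into a class-level config dict (and renames use_cuda to cuda), so callers can combine defaults with per-test overrides via deep_merge_dicts, as the test file below does. A minimal sketch of that merge pattern, using a simplified stand-in for ding.utils.deep_merge_dicts (not the library's actual implementation):

    import copy
    from easydict import EasyDict

    def deep_merge_dicts(original: dict, new: dict) -> EasyDict:
        # simplified stand-in: values in `new` override `original`, recursing into nested dicts
        merged = copy.deepcopy(dict(original))
        for k, v in new.items():
            if isinstance(v, dict) and isinstance(merged.get(k), dict):
                merged[k] = deep_merge_dicts(merged[k], v)
            else:
                merged[k] = v
        return EasyDict(merged)

    defaults = EasyDict(dict(cuda=False, forward_time=0.002, collect=dict(n_sample=80)))
    override = dict(forward_time=0.004)
    cfg = deep_merge_dicts(defaults, override)
    assert cfg.cuda is False and cfg.forward_time == 0.004 and cfg.collect.n_sample == 80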
-import time
 import logging
-from easydict import EasyDict
+import time
-import copy
 import pytest
+import numpy as np
+from easydict import EasyDict
 from functools import partial
+import copy
 from ding.worker import SampleSerialCollector, NaiveReplayBuffer
 from ding.envs import get_vec_env_setting, create_env_manager, AsyncSubprocessEnvManager, SyncSubprocessEnvManager,\
-    BaseEnvManager
-from ding.utils import deep_merge_dicts, set_pkg_seed
+    BaseEnvManager, get_env_manager_cls
+from ding.utils import deep_merge_dicts, set_pkg_seed, pretty_print
 from ding.worker.collector.tests.speed_test.fake_policy import FakePolicy
-from ding.worker.collector.tests.speed_test.fake_env import FakeEnv, env_sum
-from ding.worker.collector.tests.speed_test.test_config import test_config
+from ding.worker.collector.tests.speed_test.fake_env import FakeEnv
-# SLOW MODE:
-# - Repeat 3 times; Collect 300 iterations;
+env_policy_cfg_dict = dict(
+    # Small env and policy, such as Atari/Mujoco
+    small=dict(
+        size="small",
+        env=dict(
+            collector_env_num=8,
+            obs_dim=64,
+            action_dim=2,
+            episode_step=500,
+            reset_time=0.1,
+            step_time=0.005,
+            manager=dict(),
+        ),
+        policy=dict(forward_time=0.004),
+    ),
+    # Middle env and policy, such as Carla/Sumo/Vizdoom
+    middle=dict(
+        size="middle",
+        env=dict(
+            collector_env_num=8,
+            obs_dim=int(3e2),  # int(3e3),
+            action_dim=2,
+            episode_step=500,
+            reset_time=0.5,
+            step_time=0.01,
+            manager=dict(),
+        ),
+        policy=dict(forward_time=0.008),
+    ),
+    # Big env and policy, such as SC2 full game
+    big=dict(
+        size="big",
+        env=dict(
+            collector_env_num=8,
+            obs_dim=int(3e3),  # int(3e6),
+            action_dim=2,
+            episode_step=500,
+            reset_time=2,
+            step_time=0.1,
+            manager=dict(),
+        ),
+        policy=dict(forward_time=0.02)
+    ),
+)
+# SLOW MODE: used in normal test
+# - Repeat 3 times; Collect 300 times;
 # - Test on small + middle + big env
-# - Test on base + asynnc_subprocess + sync_subprocess env manager
+# - Test on base + async_subprocess + sync_subprocess env manager
 # - Test with reset_ratio = 1 and 5.
-# FAST MODE:
-# - Only once (No repeat); Collect 50 iterations;
+# FAST MODE: used in CI benchmark test
+# - Only once (No repeat); Collect 50 times;
 # - Test on small env
-# - Test on sync_subprocess env manager
+# - Test on base + sync_subprocess env manager
 # - Test with reset_ratio = 1.
 FAST_MODE = True
+if FAST_MODE:
+    # Note: 'base' takes approximately 6 times longer than 'subprocess'
+    test_env_manager_list = ['base', 'subprocess']
+    test_env_policy_cfg_dict = {'small': env_policy_cfg_dict['small']}
+    env_reset_ratio_list = [1]
+    repeat_times_per_test = 1
+    collect_times_per_repeat = 50
+    n_sample = 80
+else:
+    test_env_manager_list = ['base', 'async_subprocess', 'subprocess']
+    test_env_policy_cfg_dict = env_policy_cfg_dict
+    env_reset_ratio_list = [1, 5]
+    repeat_times_per_test = 3
+    collect_times_per_repeat = 300
+    n_sample = 80
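For scale, a back-of-envelope on the test matrix these switches define (my arithmetic, not stated in the commit):

    # FAST MODE: 1 cfg x 2 env managers x 1 reset ratio = 2 runs,
    # each 1 repeat x 50 collects x 80 samples = 4000 env frames per run
    # SLOW MODE: 3 cfgs x 3 env managers x 2 reset ratios = 18 runs,
    # each repeat collecting 300 x 80 = 24000 env frames, repeated 3 times
    fast_runs = 1 * 2 * 1               # 2
    slow_runs = 3 * 3 * 2               # 18
    frames_per_repeat_fast = 50 * 80    # 4000
    frames_per_repeat_slow = 300 * 80   # 24000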
-def compare_test(cfg, out_str, seed):
-    global FAST_MODE
+def compare_test(cfg: EasyDict, seed: int, test_name: str) -> None:
+    print('=' * 100 + '\nTest Name: {}\nCfg:'.format(test_name))
+    pretty_print(cfg)
     duration_list = []
-    repeat_times = 1 if FAST_MODE else 3
-    for i in range(repeat_times):
-        env_fn = FakeEnv
+    total_collected_sample = n_sample * collect_times_per_repeat
+    for i in range(repeat_times_per_test):
+        # create collector_env
         collector_env_cfg = copy.deepcopy(cfg.env)
-        collector_env_num = collector_env_cfg.collector_env_num
-        collector_env_fns = [partial(env_fn, cfg=collector_env_cfg) for _ in range(collector_env_num)]
-        if cfg.env.manager.type == 'base':
-            env_manager_type = BaseEnvManager
-        elif cfg.env.manager.type == 'async_subprocess':
-            env_manager_type = AsyncSubprocessEnvManager
-        elif cfg.env.manager.type == 'subprocess':
-            env_manager_type = SyncSubprocessEnvManager
-        env_manager_cfg = deep_merge_dicts(env_manager_type.default_config(), cfg.env.manager)
-        collector_env = env_manager_type(collector_env_fns, env_manager_cfg)
-        collector_env.seed(seed)
+        collector_env_num = collector_env_cfg.pop('collector_env_num')
+        collector_env_cfg.pop('manager')
+        collector_env_fns = [partial(FakeEnv, cfg=collector_env_cfg) for _ in range(collector_env_num)]
+        # cfg.policy.collect.collector = deep_merge_dicts(
+        #     SampleSerialCollector.default_config(), cfg.policy.collect.collector)
+        collector_env = create_env_manager(cfg.env.manager, collector_env_fns)
+        collector_env.seed(seed)
+        # create policy
         policy = FakePolicy(cfg.policy)
-        collector_cfg = deep_merge_dicts(SampleSerialCollector.default_config(), cfg.policy.collect.collector)
-        collector = SampleSerialCollector(collector_cfg, collector_env, policy.collect_mode)
-        buffer_cfg = deep_merge_dicts(cfg.policy.other.replay_buffer, NaiveReplayBuffer.default_config())
-        replay_buffer = NaiveReplayBuffer(buffer_cfg)
-        start = time.time()
-        iters = 50 if FAST_MODE else 300
-        for iter in range(iters):
-            if iter % 50 == 0:
-                print('\t', iter)
-            new_data = collector.collect(train_iter=iter)
-            replay_buffer.push(new_data, cur_collector_envstep=iter * 8)
-        duration_list.append(time.time() - start)
-        print('\tduration: {}'.format(time.time() - start))
+        # create collector and buffer
+        collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode)
+        replay_buffer = NaiveReplayBuffer(cfg.policy.other.replay_buffer)
+        # collect test
+        t1 = time.time()
+        for i in range(collect_times_per_repeat):
+            new_data = collector.collect()
+            assert len(new_data) == n_sample
+            replay_buffer.push(new_data, cur_collector_envstep=i * n_sample)
+        duration_list.append(time.time() - t1)
+        # close and release
+        collector.close()
+        replay_buffer.close()
+        del policy
+        del collector
+        del replay_buffer
-    print('avg duration: {}; ({})'.format(sum(duration_list) / len(duration_list), duration_list))
-    out_str.append('avg duration: {}; ({})'.format(sum(duration_list) / len(duration_list), duration_list))
+    fps = [total_collected_sample / duration for duration in duration_list]
+    print('\nTest Result:\nAvg FPS(env frame per second): {:.3f}±{:.3f} frame/s'.format(np.mean(fps), np.std(fps)))
+    print('=' * 100)
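The reported figure is total_collected_sample / duration per repeat. A worked example with illustrative numbers (the 20 s duration is hypothetical, not a measured result):

    n_sample, collect_times_per_repeat = 80, 50
    total_collected_sample = n_sample * collect_times_per_repeat   # 4000 env frames per repeat
    duration = 20.0                                                # hypothetical wall-clock seconds
    fps = total_collected_sample / duration                        # 200.0 frames/s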
 @pytest.mark.benchmark
 def test_collector_profile():
-    global FAST_MODE
     # ignore them for clear log
     collector_log = logging.getLogger('collector_logger')
     collector_log.disabled = True
-    buffer_log = logging.getLogger('agent_buffer_logger')
+    buffer_log = logging.getLogger('buffer_logger')
     buffer_log.disabled = True
     seed = 0
     set_pkg_seed(seed, use_cuda=False)
-    cfgs = [
-        dict(
-            size="small",
-            env=dict(env_kwargs=dict(
-                obs_dim=64,
-                action_dim=2,
-                episode_step=500,
-                reset_time=0.1,
-                step_time=0.005,
-            ), ),
-            policy=dict(forward_time=0.004, on_policy=False),
-            collector=dict(n_sample=80, ),
-        ),
-        dict(
-            size="middle",
-            env=dict(
-                env_kwargs=dict(
-                    obs_dim=int(3e2),  # int(3e3),
-                    action_dim=2,
-                    episode_step=500,
-                    reset_time=0.5,
-                    step_time=0.01,
-                ),
-            ),
-            policy=dict(forward_time=0.008, on_policy=False),
-            collector=dict(n_sample=80, ),
-        ),
-        # Big env(45min) takes much longer time than small(5min) and middle(10min).
-        dict(
-            size="big",
-            env=dict(
-                env_kwargs=dict(
-                    obs_dim=int(3e3),  # int(3e6),
-                    action_dim=2,
-                    episode_step=500,
-                    reset_time=2,
-                    step_time=0.1,
-                ),
-            ),
-            policy=dict(forward_time=0.02, on_policy=False),
-            collector=dict(n_sample=80, ),
-        ),
-    ]
-    out_str = []
-    if FAST_MODE:
-        cfgs.pop(-1)
-        cfgs.pop(-1)
-    for cfg in cfgs:
-        # Note: 'base' takes approximately 6 times longer than 'subprocess'
-        if FAST_MODE:
-            envm_list = ['subprocess']
-        else:
-            envm_list = ['base', 'async_subprocess', 'subprocess']
-        for envm in envm_list:
-            if FAST_MODE:
-                reset_list = [1]
-            else:
-                reset_list = [1, 5]  # [1, 5]
-            for reset_ratio in reset_list:
-                copy_cfg = copy.deepcopy(cfg)
-                copy_test_config = copy.deepcopy(test_config)
-                copy_cfg = EasyDict(copy_cfg)
-                copy_cfg = deep_merge_dicts(copy_test_config, copy_cfg)
-                copy_cfg.env.reset_time *= reset_ratio
-                copy_cfg.env.manager.type = envm
-                if copy_cfg.env.manager.type == 'base':
-                    copy_cfg.env.manager.pop('step_wait_timeout')
-                    copy_cfg.env.manager.pop('wait_num')
-                print('=={}, {}, reset x{}'.format(copy_cfg.size, envm, reset_ratio))
-                print(copy_cfg)
-                out_str.append('=={}, {}, reset x{}'.format(copy_cfg.size, envm, reset_ratio))
-                compare_test(copy_cfg, out_str, seed)
-    print('\n'.join(out_str))
+    for cfg_name, env_policy_cfg in test_env_policy_cfg_dict.items():
+        for env_manager_type in test_env_manager_list:
+            for env_reset_ratio in env_reset_ratio_list:
+                test_name = '{}-{}-reset{}'.format(cfg_name, env_manager_type, env_reset_ratio)
+                copy_cfg = EasyDict(copy.deepcopy(env_policy_cfg))
+                env_manager_cfg = EasyDict({'type': env_manager_type})
+                # modify args inplace
+                copy_cfg.policy = deep_merge_dicts(FakePolicy.default_config(), copy_cfg.policy)
+                copy_cfg.policy.collect.collector = deep_merge_dicts(
+                    SampleSerialCollector.default_config(), copy_cfg.policy.collect.collector
+                )
+                copy_cfg.policy.collect.collector.n_sample = n_sample
+                copy_cfg.policy.other.replay_buffer = deep_merge_dicts(
+                    NaiveReplayBuffer.default_config(), copy_cfg.policy.other.replay_buffer
+                )
+                copy_cfg.env.reset_time *= env_reset_ratio
+                copy_cfg.env.manager = get_env_manager_cls(env_manager_cfg).default_config()
+                copy_cfg.env.manager.type = env_manager_type
+                compare_test(copy_cfg, seed, test_name)
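To run just this benchmark locally, something like the following should work (the path assumes a DI-engine source checkout; -m benchmark selects the @pytest.mark.benchmark test and -s keeps the FPS printout visible):

    import pytest

    if __name__ == "__main__":
        pytest.main(["-m", "benchmark", "-s", "ding/worker/collector/tests/speed_test/test_collector_profile.py"])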
......
-from easydict import EasyDict
-nstep = 1
-test_config = dict(
-    env=dict(
-        manager=dict(
-            type='async_subprocess',
-            wait_num=7,  # 8-1
-            step_wait_timeout=0.01,
-        ),
-        collector_env_num=8,
-        evaluator_env_num=5,
-        obs_dim=8,
-        action_dim=2,
-        episode_step=200,
-        reset_time=0.01,
-        step_time=0.003,
-    ),
-    policy=dict(
-        use_cuda=False,
-        forward_time=0.002,
-        learn=dict(),
-        collect=dict(
-            n_sample=80,
-            unroll_len=1,
-            collector=dict(),
-        ),
-        eval=dict(),
-        other=dict(replay_buffer=dict(
-            type='naive',
-            replay_buffer_size=10000,
-        ), ),
-    ),
-)
-test_config = EasyDict(test_config)
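Note the merge direction the commit standardizes: the new code always passes defaults as the first argument to deep_merge_dicts so user values win, whereas the old buffer line merged (user_cfg, defaults) and let defaults clobber user settings. A flat illustration of why the order matters (example keys only, assuming the second argument overrides as the default_config() usage in this diff suggests):

    defaults = dict(replay_buffer_size=4096)
    user_cfg = dict(replay_buffer_size=10000)
    right = {**defaults, **user_cfg}   # later dict wins: 10000, like deep_merge_dicts(defaults, user_cfg)
    wrong = {**user_cfg, **defaults}   # defaults clobber the user value: 4096, the old argument order
    assert right['replay_buffer_size'] == 10000 and wrong['replay_buffer_size'] == 4096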