提交 dd6205b2 编写于 作者: N niuyazhe

fix(nyz): fix parallel test exp_name bug

上级 79587654
......@@ -153,7 +153,7 @@ def save_config_py(config_: dict, path: str) -> NoReturn:
config_string, _ = FormatCode(config_string)
config_string = config_string.replace('inf', 'float("inf")')
with open(path, "w") as f:
f.write('exp_config=' + config_string)
f.write('exp_config = ' + config_string)
def read_config_directly(path: str) -> dict:
......
......@@ -24,7 +24,7 @@ def test_config_formatted(config_path, name):
main_config, seed=0, auto=True, create_cfg=create_config, save_cfg=True, save_path='{}_config.py'.format(name)
)
module = importlib.import_module('formatted_{}_config'.format(name))
module = importlib.import_module('cartpole_{}.formatted_{}_config'.format(name, name))
main_config, create_config = module.main_config, module.create_config
cfg_test = compile_config(main_config, seed=0, auto=True, create_cfg=create_config, save_cfg=False)
assert cfg == cfg_test, 'cfg_formatted_failed'
......@@ -251,6 +251,7 @@ def save_config_formatted(config_: dict, path: str = 'formatted_total_config.py'
with open(path, "w") as f:
f.write('from easydict import EasyDict\n\n')
f.write('main_config = dict(\n')
f.write(" exp_name='{}',\n".format(config_.exp_name))
for k, v in config_.items():
if (k == 'env'):
f.write(' env=dict(\n')
......
......@@ -2,6 +2,7 @@ from easydict import EasyDict
from ding.config import parallel_transform
fake_cpong_dqn_config = dict(
exp_name='fake_cpong_dqn',
env=dict(
collector_env_num=16,
collector_episode_num=2,
......
......@@ -61,6 +61,7 @@ class NaiveCommander(BaseCommander):
"collector_task_space" and "learner_task_space".
"""
self._cfg = cfg
self._exp_name = cfg.exp_name
commander_cfg = self._cfg.policy.other.commander
self._collector_task_space = LimitedSpaceContainer(0, commander_cfg.collector_task_space)
self._learner_task_space = LimitedSpaceContainer(0, commander_cfg.learner_task_space)
......@@ -91,6 +92,7 @@ class NaiveCommander(BaseCommander):
collector_cfg.policy = copy.deepcopy(self._cfg.policy)
collector_cfg.policy_update_path = 'test.pth'
collector_cfg.env = self._collector_env_cfg
collector_cfg.exp_name = self._exp_name
return {
'task_id': 'collector_task_id{}'.format(self._collector_task_count),
'buffer_id': 'test',
......@@ -109,6 +111,7 @@ class NaiveCommander(BaseCommander):
if self._learner_task_space.acquire_space():
self._learner_task_count += 1
learner_cfg = copy.deepcopy(self._cfg.policy.learn.learner)
learner_cfg.exp_name = self._exp_name
return {
'task_id': 'learner_task_id{}'.format(self._learner_task_count),
'policy_id': 'test.pth',
......
......@@ -10,6 +10,7 @@ def setup_1v1commander():
nstep = 1
eval_interval = 5
main_config = dict(
exp_name='one_vs_one_test',
env=dict(
collector_env_num=8,
collector_episode_num=2,
......
......@@ -247,8 +247,16 @@ class TestDemonstrationBuffer:
def test_naive(self, setup_demo_buffer_factory):
setup_demo_buffer = next(setup_demo_buffer_factory)
naive_demo_buffer = next(setup_demo_buffer_factory)
with open(demo_data_path, 'rb+') as f:
data = pickle.load(f)
while True:
with open(demo_data_path, 'rb+') as f:
data = pickle.load(f)
if len(data) != 0:
break
else: # for the stability of dist-test
demo_data = {'data': generate_data_list(10)}
with open(demo_data_path, "wb") as f:
pickle.dump(demo_data, f)
setup_demo_buffer.load_state_dict(data)
assert setup_demo_buffer.count() == len(data['data']) # assert buffer not empty
samples = setup_demo_buffer.sample(3, 0)
......
from easydict import EasyDict
cartpole_a2c_config = dict(
exp_name='cartpole_a2c',
env=dict(
collector_env_num=8,
evaluator_env_num=5,
......
from easydict import EasyDict
cartpole_c51_config = dict(
exp_name='cartpole_c51',
env=dict(
collector_env_num=8,
evaluator_env_num=5,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册