提交 ae6ab6c7 编写于 作者: N niuyazhe

fix(nyz): fix exp_name seedx name bug with data generation path

上级 35241df3
...@@ -154,6 +154,7 @@ def collect_demo_data( ...@@ -154,6 +154,7 @@ def collect_demo_data(
if cfg.policy.cuda: if cfg.policy.cuda:
exp_data = to_device(exp_data, 'cpu') exp_data = to_device(exp_data, 'cpu')
# Save data transitions. # Save data transitions.
expert_data_path = os.path.join(cfg.exp_name, expert_data_path)
offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive')) offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
print('Collect demo data successfully') print('Collect demo data successfully')
...@@ -227,6 +228,7 @@ def collect_episodic_demo_data( ...@@ -227,6 +228,7 @@ def collect_episodic_demo_data(
if cfg.policy.cuda: if cfg.policy.cuda:
exp_data = to_device(exp_data, 'cpu') exp_data = to_device(exp_data, 'cpu')
# Save data transitions. # Save data transitions.
expert_data_path = os.path.join(cfg.exp_name, expert_data_path)
offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive')) offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
print('Collect episodic demo data successfully') print('Collect episodic demo data successfully')
......
...@@ -2,6 +2,7 @@ import pytest ...@@ -2,6 +2,7 @@ import pytest
import time import time
import os import os
from copy import deepcopy from copy import deepcopy
import torch
from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_offline from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_offline
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
...@@ -360,7 +361,9 @@ def test_sqn(): ...@@ -360,7 +361,9 @@ def test_sqn():
@pytest.mark.unittest @pytest.mark.unittest
def test_selfplay(): def test_selfplay():
try: try:
selfplay_main(deepcopy(league_demo_ppo_config), seed=0, max_iterations=1) config = deepcopy(league_demo_ppo_config)
config.exp_name = 'test_selfplay'
selfplay_main(config, seed=0, max_iterations=1)
except Exception: except Exception:
assert False, "pipeline fail" assert False, "pipeline fail"
...@@ -368,7 +371,9 @@ def test_selfplay(): ...@@ -368,7 +371,9 @@ def test_selfplay():
@pytest.mark.unittest @pytest.mark.unittest
def test_league(): def test_league():
try: try:
league_main(deepcopy(league_demo_ppo_config), seed=0, max_iterations=1) config = deepcopy(league_demo_ppo_config)
config.exp_name = 'test_league'
league_main(config, seed=0, max_iterations=1)
except Exception as e: except Exception as e:
assert False, "pipeline fail" assert False, "pipeline fail"
...@@ -395,14 +400,13 @@ def test_cql(): ...@@ -395,14 +400,13 @@ def test_cql():
assert False, "pipeline fail" assert False, "pipeline fail"
# collect expert data # collect expert data
import torch
config = [ config = [
deepcopy(pendulum_sac_data_genearation_default_config), deepcopy(pendulum_sac_data_genearation_default_config),
deepcopy(pendulum_sac_data_genearation_default_create_config) deepcopy(pendulum_sac_data_genearation_default_create_config)
] ]
collect_count = 1000 collect_count = 1000
expert_data_path = config[0].policy.collect.save_path expert_data_path = config[0].policy.collect.save_path
state_dict = torch.load('./sac/ckpt/iteration_0.pth.tar', map_location='cpu') state_dict = torch.load('./sac_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
try: try:
collect_demo_data( collect_demo_data(
config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
...@@ -442,11 +446,10 @@ def test_discrete_cql(): ...@@ -442,11 +446,10 @@ def test_discrete_cql():
except Exception: except Exception:
assert False, "pipeline fail" assert False, "pipeline fail"
# collect expert data # collect expert data
import torch
config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)] config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
collect_count = 1000 collect_count = 1000
expert_data_path = config[0].policy.collect.save_path expert_data_path = config[0].policy.collect.save_path
state_dict = torch.load('./cql_cartpole/ckpt/iteration_0.pth.tar', map_location='cpu') state_dict = torch.load('./cql_cartpole_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
try: try:
collect_demo_data( collect_demo_data(
config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
...@@ -467,7 +470,7 @@ def test_discrete_cql(): ...@@ -467,7 +470,7 @@ def test_discrete_cql():
os.popen('rm -rf cartpole cartpole_cql') os.popen('rm -rf cartpole cartpole_cql')
@pytest.mark.algotest @pytest.mark.unittest
def test_td3_bc(): def test_td3_bc():
# train expert # train expert
config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)] config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
...@@ -479,11 +482,10 @@ def test_td3_bc(): ...@@ -479,11 +482,10 @@ def test_td3_bc():
assert False, "pipeline fail" assert False, "pipeline fail"
# collect expert data # collect expert data
import torch
config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)] config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)]
collect_count = 1000 collect_count = 1000
expert_data_path = config[0].policy.collect.save_path expert_data_path = config[0].policy.collect.save_path
state_dict = torch.load('./td3/ckpt/iteration_0.pth.tar', map_location='cpu') state_dict = torch.load('./td3_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
try: try:
collect_demo_data( collect_demo_data(
config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
......
...@@ -281,7 +281,9 @@ def test_acer(): ...@@ -281,7 +281,9 @@ def test_acer():
@pytest.mark.algotest @pytest.mark.algotest
def test_selfplay(): def test_selfplay():
try: try:
selfplay_main(deepcopy(league_demo_ppo_config), seed=0) config = deepcopy(league_demo_ppo_config)
config.exp_name = 'test_selfplay'
selfplay_main(config, seed=0)
except Exception: except Exception:
assert False, "pipeline fail" assert False, "pipeline fail"
with open("./algo_record.log", "a+") as f: with open("./algo_record.log", "a+") as f:
...@@ -291,7 +293,9 @@ def test_selfplay(): ...@@ -291,7 +293,9 @@ def test_selfplay():
@pytest.mark.algotest @pytest.mark.algotest
def test_league(): def test_league():
try: try:
league_main(deepcopy(league_demo_ppo_config), seed=0) config = deepcopy(league_demo_ppo_config)
config.exp_name = 'test_league'
league_main(config, seed=0)
except Exception: except Exception:
assert False, "pipeline fail" assert False, "pipeline fail"
with open("./algo_record.log", "a+") as f: with open("./algo_record.log", "a+") as f:
...@@ -326,14 +330,13 @@ def test_cql(): ...@@ -326,14 +330,13 @@ def test_cql():
assert False, "pipeline fail" assert False, "pipeline fail"
# collect expert data # collect expert data
import torch
config = [ config = [
deepcopy(pendulum_sac_data_genearation_default_config), deepcopy(pendulum_sac_data_genearation_default_config),
deepcopy(pendulum_sac_data_genearation_default_create_config) deepcopy(pendulum_sac_data_genearation_default_create_config)
] ]
collect_count = config[0].policy.other.replay_buffer.replay_buffer_size collect_count = config[0].policy.other.replay_buffer.replay_buffer_size
expert_data_path = config[0].policy.collect.save_path expert_data_path = config[0].policy.collect.save_path
state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu') state_dict = torch.load('./sac_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
try: try:
collect_demo_data( collect_demo_data(
config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
...@@ -362,11 +365,10 @@ def test_discrete_cql(): ...@@ -362,11 +365,10 @@ def test_discrete_cql():
assert False, "pipeline fail" assert False, "pipeline fail"
# collect expert data # collect expert data
import torch
config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)] config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
collect_count = config[0].policy.other.replay_buffer.replay_buffer_size collect_count = config[0].policy.other.replay_buffer.replay_buffer_size
expert_data_path = config[0].policy.collect.save_path expert_data_path = config[0].policy.collect.save_path
state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu') state_dict = torch.load('./cql_cartpole_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
try: try:
collect_demo_data( collect_demo_data(
config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
...@@ -406,11 +408,10 @@ def test_td3_bc(): ...@@ -406,11 +408,10 @@ def test_td3_bc():
assert False, "pipeline fail" assert False, "pipeline fail"
# collect expert data # collect expert data
import torch
config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)] config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)]
collect_count = config[0].policy.other.replay_buffer.replay_buffer_size collect_count = config[0].policy.other.replay_buffer.replay_buffer_size
expert_data_path = config[0].policy.collect.save_path expert_data_path = config[0].policy.collect.save_path
state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu') state_dict = torch.load('./td3_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
try: try:
collect_demo_data( collect_demo_data(
config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
......
...@@ -71,12 +71,17 @@ class DQNPolicy(Policy): ...@@ -71,12 +71,17 @@ class DQNPolicy(Policy):
config = dict( config = dict(
type='dqn', type='dqn',
# (bool) Whether use cuda in policy
cuda=False, cuda=False,
# (bool) Whether learning policy is the same as collecting data policy(on-policy)
on_policy=False, on_policy=False,
# (bool) Whether enable priority experience sample
priority=False, priority=False,
# (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
priority_IS_weight=False, priority_IS_weight=False,
# (float) Discount factor(gamma) for returns
discount_factor=0.97, discount_factor=0.97,
# (int) The number of step for calculating target q_value
nstep=1, nstep=1,
learn=dict( learn=dict(
# (bool) Whether to use multi gpu # (bool) Whether to use multi gpu
......
...@@ -3,6 +3,7 @@ from ding.entry import serial_pipeline ...@@ -3,6 +3,7 @@ from ding.entry import serial_pipeline
from easydict import EasyDict from easydict import EasyDict
pong_qrdqn_config = dict( pong_qrdqn_config = dict(
exp_name='pong_qrdqn_generation',
env=dict( env=dict(
collector_env_num=8, collector_env_num=8,
evaluator_env_num=8, evaluator_env_num=8,
...@@ -39,7 +40,7 @@ pong_qrdqn_config = dict( ...@@ -39,7 +40,7 @@ pong_qrdqn_config = dict(
collect=dict( collect=dict(
n_sample=100, n_sample=100,
data_type='hdf5', data_type='hdf5',
save_path='./expert/expert.pkl', save_path='expert.pkl',
), ),
eval=dict(evaluator=dict(eval_freq=4000, )), eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict( other=dict(
......
...@@ -3,6 +3,7 @@ from ding.entry import serial_pipeline ...@@ -3,6 +3,7 @@ from ding.entry import serial_pipeline
from easydict import EasyDict from easydict import EasyDict
qbert_qrdqn_config = dict( qbert_qrdqn_config = dict(
exp_name='qbert_qrdqn_geneation',
env=dict( env=dict(
collector_env_num=8, collector_env_num=8,
evaluator_env_num=8, evaluator_env_num=8,
...@@ -39,7 +40,7 @@ qbert_qrdqn_config = dict( ...@@ -39,7 +40,7 @@ qbert_qrdqn_config = dict(
collect=dict( collect=dict(
n_sample=100, n_sample=100,
data_type='hdf5', data_type='hdf5',
save_path='./expert/expert.pkl', save_path='expert.pkl',
), ),
eval=dict(evaluator=dict(eval_freq=4000, )), eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict( other=dict(
......
...@@ -29,7 +29,7 @@ cartpole_discrete_cql_config = dict( ...@@ -29,7 +29,7 @@ cartpole_discrete_cql_config = dict(
), ),
collect=dict( collect=dict(
data_type='hdf5', data_type='hdf5',
data_path='./cartpole_generation/expert_demos.hdf5', data_path='./cartpole_generation_seed0/expert_demos.hdf5', # user-specific
n_sample=80, n_sample=80,
unroll_len=1, unroll_len=1,
), ),
......
...@@ -37,7 +37,7 @@ cartpole_qrdqn_generation_data_config = dict( ...@@ -37,7 +37,7 @@ cartpole_qrdqn_generation_data_config = dict(
n_sample=80, n_sample=80,
unroll_len=1, unroll_len=1,
data_type='hdf5', data_type='hdf5',
save_path='./cartpole_generation/expert.pkl', save_path='expert.pkl',
), ),
other=dict( other=dict(
eps=dict( eps=dict(
......
...@@ -37,7 +37,7 @@ pendulum_cql_default_config = dict( ...@@ -37,7 +37,7 @@ pendulum_cql_default_config = dict(
n_sample=1, n_sample=1,
unroll_len=1, unroll_len=1,
data_type='hdf5', data_type='hdf5',
data_path='./sac/expert_demos.hdf5', data_path='./peudulum_sac_generation_seed0/expert_demos.hdf5', # user-specific
), ),
command=dict(), command=dict(),
eval=dict(evaluator=dict(eval_freq=100, )), eval=dict(evaluator=dict(eval_freq=100, )),
......
from easydict import EasyDict from easydict import EasyDict
pendulum_sac_data_genearation_default_config = dict( pendulum_sac_data_genearation_default_config = dict(
exp_name='peudulum_sac_generation',
seed=0, seed=0,
env=dict( env=dict(
collector_env_num=10, collector_env_num=10,
...@@ -43,7 +44,7 @@ pendulum_sac_data_genearation_default_config = dict( ...@@ -43,7 +44,7 @@ pendulum_sac_data_genearation_default_config = dict(
collect=dict( collect=dict(
n_sample=1, n_sample=1,
unroll_len=1, unroll_len=1,
save_path='./sac/expert.pkl', save_path='expert.pkl',
data_type='hdf5', data_type='hdf5',
), ),
command=dict(), command=dict(),
......
...@@ -44,7 +44,7 @@ pendulum_td3_bc_config = dict( ...@@ -44,7 +44,7 @@ pendulum_td3_bc_config = dict(
noise_sigma=0.1, noise_sigma=0.1,
collector=dict(collect_print_freq=1000, ), collector=dict(collect_print_freq=1000, ),
data_type='hdf5', data_type='hdf5',
data_path='./td3/expert_demos.hdf5', data_path='./pendulum_td3_generation_seed0/expert_demos.hdf5', # user-specific
normalize_states=True, normalize_states=True,
), ),
eval=dict(evaluator=dict(eval_freq=100, ), ), eval=dict(evaluator=dict(eval_freq=100, ), ),
......
from easydict import EasyDict from easydict import EasyDict
pendulum_td3_generation_config = dict( pendulum_td3_generation_config = dict(
exp_name='td3', exp_name='pendulum_td3_generation',
env=dict( env=dict(
collector_env_num=8, collector_env_num=8,
evaluator_env_num=10, evaluator_env_num=10,
...@@ -45,7 +45,7 @@ pendulum_td3_generation_config = dict( ...@@ -45,7 +45,7 @@ pendulum_td3_generation_config = dict(
n_sample=10, n_sample=10,
noise_sigma=0.1, noise_sigma=0.1,
collector=dict(collect_print_freq=1000, ), collector=dict(collect_print_freq=1000, ),
save_path='./td3/expert.pkl', save_path='expert.pkl',
data_type='hdf5', data_type='hdf5',
), ),
eval=dict(evaluator=dict(eval_freq=100, ), ), eval=dict(evaluator=dict(eval_freq=100, ), ),
......
from easydict import EasyDict from easydict import EasyDict
hopper_sac_data_genearation_default_config = dict( hopper_sac_data_genearation_default_config = dict(
exp='hopper_sac_generation',
env=dict( env=dict(
env_id='Hopper-v3', env_id='Hopper-v3',
norm_obs=dict(use_norm=False, ), norm_obs=dict(use_norm=False, ),
...@@ -45,7 +46,7 @@ hopper_sac_data_genearation_default_config = dict( ...@@ -45,7 +46,7 @@ hopper_sac_data_genearation_default_config = dict(
collect=dict( collect=dict(
n_sample=1, n_sample=1,
unroll_len=1, unroll_len=1,
save_path='./default_experiment/expert_iteration_200000.pkl', save_path='expert_iteration_200000.pkl',
), ),
command=dict(), command=dict(),
eval=dict(), eval=dict(),
......
from easydict import EasyDict from easydict import EasyDict
halfcheetah_td3_default_config = dict( halfcheetah_td3_default_config = dict(
exp_name='halfcheetah_td3_generation',
env=dict( env=dict(
env_id='Hopper-v3', env_id='Hopper-v3',
norm_obs=dict(use_norm=False, ), norm_obs=dict(use_norm=False, ),
...@@ -49,7 +50,7 @@ halfcheetah_td3_default_config = dict( ...@@ -49,7 +50,7 @@ halfcheetah_td3_default_config = dict(
n_sample=1, n_sample=1,
unroll_len=1, unroll_len=1,
noise_sigma=0.1, noise_sigma=0.1,
save_path='./td3/expert.pkl', save_path='expert.pkl',
data_type='hdf5', data_type='hdf5',
), ),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册