Commit ae6ab6c7 authored by niuyazhe

fix(nyz): fix exp_name seed-suffix naming bug in data-generation paths

Parent 35241df3
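Context for the fix: the serial pipeline suffixes cfg.exp_name with the seed (e.g. 'sac' becomes 'sac_seed0'), so hard-coded data-generation paths such as './sac/expert.pkl' no longer matched the actual output directory. A minimal sketch of the corrected path logic, assuming the '_seed{seed}' suffix convention (the helper name is illustrative, not DI-engine API):

import os

def resolve_expert_data_path(exp_name: str, seed: int, relative_path: str) -> str:
    # Experiment artifacts live under '<exp_name>_seed<seed>/' (assumed convention),
    # so a relative save path must be joined onto that directory.
    exp_dir = f'{exp_name}_seed{seed}'
    return os.path.join(exp_dir, relative_path)

# 'sac' + seed 0 + 'expert.pkl' -> 'sac_seed0/expert.pkl'
print(resolve_expert_data_path('sac', 0, 'expert.pkl'))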
......@@ -154,6 +154,7 @@ def collect_demo_data(
if cfg.policy.cuda:
exp_data = to_device(exp_data, 'cpu')
# Save data transitions.
expert_data_path = os.path.join(cfg.exp_name, expert_data_path)
offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
print('Collect demo data successfully')
......@@ -227,6 +228,7 @@ def collect_episodic_demo_data(
if cfg.policy.cuda:
exp_data = to_device(exp_data, 'cpu')
# Save data transitions.
expert_data_path = os.path.join(cfg.exp_name, expert_data_path)
offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
print('Collect episodic demo data successfully')
......
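With the added os.path.join(cfg.exp_name, expert_data_path) line, the expert_data_path argument to both entry functions is now interpreted relative to the experiment directory. A hedged usage sketch mirroring the updated tests; the import path is assumed to follow the cartpole config layout shown below, and the checkpoint path assumes a prior training run with seed 0:

import torch
from copy import deepcopy
from ding.entry import collect_demo_data
from dizoo.classic_control.cartpole.config.cartpole_qrdqn_generation_data_config import \
    cartpole_qrdqn_generation_data_config, cartpole_qrdqn_generation_data_create_config

config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
# save_path is now a bare file name ('expert.pkl'); collect_demo_data joins it
# onto cfg.exp_name, so the data lands inside the experiment directory.
state_dict = torch.load('./cql_cartpole_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
collect_demo_data(
    config, seed=0, collect_count=1000,
    expert_data_path=config[0].policy.collect.save_path, state_dict=state_dict
)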
......@@ -2,6 +2,7 @@ import pytest
import time
import os
from copy import deepcopy
import torch
from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_offline
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
......@@ -360,7 +361,9 @@ def test_sqn():
@pytest.mark.unittest
def test_selfplay():
try:
selfplay_main(deepcopy(league_demo_ppo_config), seed=0, max_iterations=1)
config = deepcopy(league_demo_ppo_config)
config.exp_name = 'test_selfplay'
selfplay_main(config, seed=0, max_iterations=1)
except Exception:
assert False, "pipeline fail"
......@@ -368,7 +371,9 @@ def test_selfplay():
@pytest.mark.unittest
def test_league():
try:
league_main(deepcopy(league_demo_ppo_config), seed=0, max_iterations=1)
config = deepcopy(league_demo_ppo_config)
config.exp_name = 'test_league'
league_main(config, seed=0, max_iterations=1)
except Exception as e:
assert False, "pipeline fail"
......@@ -395,14 +400,13 @@ def test_cql():
assert False, "pipeline fail"
# collect expert data
import torch
config = [
deepcopy(pendulum_sac_data_genearation_default_config),
deepcopy(pendulum_sac_data_genearation_default_create_config)
]
collect_count = 1000
expert_data_path = config[0].policy.collect.save_path
state_dict = torch.load('./sac/ckpt/iteration_0.pth.tar', map_location='cpu')
state_dict = torch.load('./sac_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
try:
collect_demo_data(
config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
......@@ -442,11 +446,10 @@ def test_discrete_cql():
except Exception:
assert False, "pipeline fail"
# collect expert data
import torch
config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
collect_count = 1000
expert_data_path = config[0].policy.collect.save_path
state_dict = torch.load('./cql_cartpole/ckpt/iteration_0.pth.tar', map_location='cpu')
state_dict = torch.load('./cql_cartpole_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
try:
collect_demo_data(
config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
......@@ -467,7 +470,7 @@ def test_discrete_cql():
os.popen('rm -rf cartpole cartpole_cql')
@pytest.mark.algotest
@pytest.mark.unittest
def test_td3_bc():
# train expert
config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
......@@ -479,11 +482,10 @@ def test_td3_bc():
assert False, "pipeline fail"
# collect expert data
import torch
config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)]
collect_count = 1000
expert_data_path = config[0].policy.collect.save_path
state_dict = torch.load('./td3/ckpt/iteration_0.pth.tar', map_location='cpu')
state_dict = torch.load('./td3_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
try:
collect_demo_data(
config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
......
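The test updates above follow from the same naming convention: because training output goes to '<exp_name>_seed<seed>/', checkpoints are loaded from e.g. './td3_seed0/ckpt/iteration_0.pth.tar' instead of './td3/ckpt/...'. A small sketch of the layout, with the suffix rule stated as an assumption:

def ckpt_path(exp_name: str, seed: int, iteration: int = 0) -> str:
    # Assumed layout: '<exp_name>_seed<seed>/ckpt/iteration_<n>.pth.tar'
    return f'{exp_name}_seed{seed}/ckpt/iteration_{iteration}.pth.tar'

assert ckpt_path('td3', 0) == 'td3_seed0/ckpt/iteration_0.pth.tar'
assert ckpt_path('cql_cartpole', 0) == 'cql_cartpole_seed0/ckpt/iteration_0.pth.tar'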
......@@ -281,7 +281,9 @@ def test_acer():
@pytest.mark.algotest
def test_selfplay():
try:
selfplay_main(deepcopy(league_demo_ppo_config), seed=0)
config = deepcopy(league_demo_ppo_config)
config.exp_name = 'test_selfplay'
selfplay_main(config, seed=0)
except Exception:
assert False, "pipeline fail"
with open("./algo_record.log", "a+") as f:
......@@ -291,7 +293,9 @@ def test_selfplay():
@pytest.mark.algotest
def test_league():
try:
league_main(deepcopy(league_demo_ppo_config), seed=0)
config = deepcopy(league_demo_ppo_config)
config.exp_name = 'test_league'
league_main(config, seed=0)
except Exception:
assert False, "pipeline fail"
with open("./algo_record.log", "a+") as f:
......@@ -326,14 +330,13 @@ def test_cql():
assert False, "pipeline fail"
# collect expert data
import torch
config = [
deepcopy(pendulum_sac_data_genearation_default_config),
deepcopy(pendulum_sac_data_genearation_default_create_config)
]
collect_count = config[0].policy.other.replay_buffer.replay_buffer_size
expert_data_path = config[0].policy.collect.save_path
state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu')
state_dict = torch.load('./sac_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
try:
collect_demo_data(
config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
......@@ -362,11 +365,10 @@ def test_discrete_cql():
assert False, "pipeline fail"
# collect expert data
import torch
config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
collect_count = config[0].policy.other.replay_buffer.replay_buffer_size
expert_data_path = config[0].policy.collect.save_path
state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu')
state_dict = torch.load('./cql_cartpole_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
try:
collect_demo_data(
config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
......@@ -406,11 +408,10 @@ def test_td3_bc():
assert False, "pipeline fail"
# collect expert data
import torch
config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)]
collect_count = config[0].policy.other.replay_buffer.replay_buffer_size
expert_data_path = config[0].policy.collect.save_path
state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu')
state_dict = torch.load('./td3_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
try:
collect_demo_data(
config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
......
......@@ -71,12 +71,17 @@ class DQNPolicy(Policy):
config = dict(
type='dqn',
# (bool) Whether to use cuda in policy
cuda=False,
# (bool) Whether the learning policy is the same as the data-collecting policy (on-policy)
on_policy=False,
# (bool) Whether to enable prioritized experience replay
priority=False,
# (bool) Whether to use Importance Sampling weights to correct the biased update. If True, priority must be True.
priority_IS_weight=False,
# (float) Discount factor (gamma) for returns
discount_factor=0.97,
# (int) The number of steps used to calculate the target q_value
nstep=1,
learn=dict(
# (bool) Whether to use multi-GPU training
......
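The new comments spell out the config contract, in particular that priority_IS_weight=True is only meaningful together with priority=True. A minimal sketch of that consistency check (the assertion is illustrative, not DI-engine code):

from easydict import EasyDict

cfg = EasyDict(dict(
    type='dqn',
    cuda=False,
    on_policy=False,
    priority=False,
    priority_IS_weight=False,
    discount_factor=0.97,
    nstep=1,
))

# Importance-sampling correction only makes sense with prioritized sampling.
assert not cfg.priority_IS_weight or cfg.priority, \
    'priority must be True when priority_IS_weight is True'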
......@@ -3,6 +3,7 @@ from ding.entry import serial_pipeline
from easydict import EasyDict
pong_qrdqn_config = dict(
exp_name='pong_qrdqn_generation',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
......@@ -39,7 +40,7 @@ pong_qrdqn_config = dict(
collect=dict(
n_sample=100,
data_type='hdf5',
save_path='./expert/expert.pkl',
save_path='expert.pkl',
),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
......
......@@ -3,6 +3,7 @@ from ding.entry import serial_pipeline
from easydict import EasyDict
qbert_qrdqn_config = dict(
exp_name='qbert_qrdqn_generation',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
......@@ -39,7 +40,7 @@ qbert_qrdqn_config = dict(
collect=dict(
n_sample=100,
data_type='hdf5',
save_path='./expert/expert.pkl',
save_path='expert.pkl',
),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
......
......@@ -29,7 +29,7 @@ cartpole_discrete_cql_config = dict(
),
collect=dict(
data_type='hdf5',
data_path='./cartpole_generation/expert_demos.hdf5',
data_path='./cartpole_generation_seed0/expert_demos.hdf5', # user-specific
n_sample=80,
unroll_len=1,
),
......
......@@ -37,7 +37,7 @@ cartpole_qrdqn_generation_data_config = dict(
n_sample=80,
unroll_len=1,
data_type='hdf5',
save_path='./cartpole_generation/expert.pkl',
save_path='expert.pkl',
),
other=dict(
eps=dict(
......
......@@ -37,7 +37,7 @@ pendulum_cql_default_config = dict(
n_sample=1,
unroll_len=1,
data_type='hdf5',
data_path='./sac/expert_demos.hdf5',
data_path='./pendulum_sac_generation_seed0/expert_demos.hdf5', # user-specific
),
command=dict(),
eval=dict(evaluator=dict(eval_freq=100, )),
......
from easydict import EasyDict
pendulum_sac_data_genearation_default_config = dict(
exp_name='pendulum_sac_generation',
seed=0,
env=dict(
collector_env_num=10,
......@@ -43,7 +44,7 @@ pendulum_sac_data_genearation_default_config = dict(
collect=dict(
n_sample=1,
unroll_len=1,
save_path='./sac/expert.pkl',
save_path='expert.pkl',
data_type='hdf5',
),
command=dict(),
......
......@@ -44,7 +44,7 @@ pendulum_td3_bc_config = dict(
noise_sigma=0.1,
collector=dict(collect_print_freq=1000, ),
data_type='hdf5',
data_path='./td3/expert_demos.hdf5',
data_path='./pendulum_td3_generation_seed0/expert_demos.hdf5', # user-specific
normalize_states=True,
),
eval=dict(evaluator=dict(eval_freq=100, ), ),
......
from easydict import EasyDict
pendulum_td3_generation_config = dict(
exp_name='td3',
exp_name='pendulum_td3_generation',
env=dict(
collector_env_num=8,
evaluator_env_num=10,
......@@ -45,7 +45,7 @@ pendulum_td3_generation_config = dict(
n_sample=10,
noise_sigma=0.1,
collector=dict(collect_print_freq=1000, ),
save_path='./td3/expert.pkl',
save_path='expert.pkl',
data_type='hdf5',
),
eval=dict(evaluator=dict(eval_freq=100, ), ),
......
from easydict import EasyDict
hopper_sac_data_genearation_default_config = dict(
exp_name='hopper_sac_generation',
env=dict(
env_id='Hopper-v3',
norm_obs=dict(use_norm=False, ),
......@@ -45,7 +46,7 @@ hopper_sac_data_genearation_default_config = dict(
collect=dict(
n_sample=1,
unroll_len=1,
save_path='./default_experiment/expert_iteration_200000.pkl',
save_path='expert_iteration_200000.pkl',
),
command=dict(),
eval=dict(),
......
from easydict import EasyDict
halfcheetah_td3_default_config = dict(
exp_name='halfcheetah_td3_generation',
env=dict(
env_id='HalfCheetah-v3',
norm_obs=dict(use_norm=False, ),
......@@ -49,7 +50,7 @@ halfcheetah_td3_default_config = dict(
n_sample=1,
unroll_len=1,
noise_sigma=0.1,
save_path='./td3/expert.pkl',
save_path='expert.pkl',
data_type='hdf5',
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
......
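Taken together, the config changes follow one convention: generation configs keep save_path as a bare file name that is joined onto the seed-suffixed exp_name at save time, while the offline-training configs reference the resulting directory explicitly via a user-specific data_path. A sketch of the round trip, assuming the '_seed0' suffix and that the hdf5 writer emits 'expert_demos.hdf5' (as the data_path entries above suggest):

import os

exp_name, seed = 'pendulum_td3_generation', 0
exp_dir = f'{exp_name}_seed{seed}'

# Generation side: relative save_path joined onto the experiment directory.
save_path = os.path.join(exp_dir, 'expert.pkl')

# Offline side: data_path points at the hdf5 file produced in that directory.
data_path = f'./{exp_dir}/expert_demos.hdf5'

print(save_path)   # pendulum_td3_generation_seed0/expert.pkl
print(data_path)   # ./pendulum_td3_generation_seed0/expert_demos.hdf5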