From ae6ab6c7d6f595b0ed53edc97df58ac0733786cc Mon Sep 17 00:00:00 2001
From: niuyazhe
Date: Sat, 1 Jan 2022 21:57:49 +0800
Subject: [PATCH] fix(nyz): fix exp_name seedx name bug with data generation
 path

---
 ding/entry/application_entry.py                    |  2 ++
 ding/entry/tests/test_serial_entry.py              | 20 ++++++++++---------
 ding/entry/tests/test_serial_entry_algo.py         | 17 ++++++++--------
 ding/policy/dqn.py                                 |  5 +++++
 .../pong/pong_qrdqn_generation_data_config.py      |  3 ++-
 .../qbert_qrdqn_generation_data_config.py          |  3 ++-
 .../cartpole/config/cartpole_cql_config.py         |  2 +-
 .../cartpole_qrdqn_generation_data_config.py       |  2 +-
 .../pendulum/config/pendulum_cql_config.py         |  2 +-
 ...ulum_sac_data_generation_default_config.py      |  3 ++-
 .../pendulum/config/pendulum_td3_bc_config.py      |  2 +-
 .../pendulum_td3_data_generation_config.py         |  4 ++--
 ...pper_sac_data_generation_default_config.py      |  3 ++-
 .../hopper_td3_data_generation_config.py           |  3 ++-
 14 files changed, 43 insertions(+), 28 deletions(-)

diff --git a/ding/entry/application_entry.py b/ding/entry/application_entry.py
index 0d77c39..2367b8a 100644
--- a/ding/entry/application_entry.py
+++ b/ding/entry/application_entry.py
@@ -154,6 +154,7 @@ def collect_demo_data(
     if cfg.policy.cuda:
         exp_data = to_device(exp_data, 'cpu')
     # Save data transitions.
+    expert_data_path = os.path.join(cfg.exp_name, expert_data_path)
     offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
     print('Collect demo data successfully')

@@ -227,6 +228,7 @@ def collect_episodic_demo_data(
     if cfg.policy.cuda:
         exp_data = to_device(exp_data, 'cpu')
     # Save data transitions.
+    expert_data_path = os.path.join(cfg.exp_name, expert_data_path)
     offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
     print('Collect episodic demo data successfully')

diff --git a/ding/entry/tests/test_serial_entry.py b/ding/entry/tests/test_serial_entry.py
index d80f4b8..b8c0904 100644
--- a/ding/entry/tests/test_serial_entry.py
+++ b/ding/entry/tests/test_serial_entry.py
@@ -2,6 +2,7 @@ import pytest
 import time
 import os
 from copy import deepcopy
+import torch

 from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_offline
 from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
@@ -360,7 +361,9 @@ def test_sqn():
 @pytest.mark.unittest
 def test_selfplay():
     try:
-        selfplay_main(deepcopy(league_demo_ppo_config), seed=0, max_iterations=1)
+        config = deepcopy(league_demo_ppo_config)
+        config.exp_name = 'test_selfplay'
+        selfplay_main(config, seed=0, max_iterations=1)
     except Exception:
         assert False, "pipeline fail"

@@ -368,7 +371,9 @@ def test_selfplay():
 @pytest.mark.unittest
 def test_league():
     try:
-        league_main(deepcopy(league_demo_ppo_config), seed=0, max_iterations=1)
+        config = deepcopy(league_demo_ppo_config)
+        config.exp_name = 'test_league'
+        league_main(config, seed=0, max_iterations=1)
     except Exception as e:
         assert False, "pipeline fail"

@@ -395,14 +400,13 @@ def test_cql():
         assert False, "pipeline fail"

     # collect expert data
-    import torch
     config = [
         deepcopy(pendulum_sac_data_genearation_default_config),
         deepcopy(pendulum_sac_data_genearation_default_create_config)
     ]
     collect_count = 1000
     expert_data_path = config[0].policy.collect.save_path
-    state_dict = torch.load('./sac/ckpt/iteration_0.pth.tar', map_location='cpu')
+    state_dict = torch.load('./sac_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
     try:
         collect_demo_data(
             config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
@@ -442,11 +446,10 @@ def test_discrete_cql():
     except Exception:
         assert False, "pipeline fail"
     # collect expert data
-    import torch
     config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
     collect_count = 1000
     expert_data_path = config[0].policy.collect.save_path
-    state_dict = torch.load('./cql_cartpole/ckpt/iteration_0.pth.tar', map_location='cpu')
+    state_dict = torch.load('./cql_cartpole_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
     try:
         collect_demo_data(
             config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
@@ -467,7 +470,7 @@
         os.popen('rm -rf cartpole cartpole_cql')


-@pytest.mark.algotest
+@pytest.mark.unittest
 def test_td3_bc():
     # train expert
     config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
@@ -479,11 +482,10 @@ def test_td3_bc():
         assert False, "pipeline fail"

     # collect expert data
-    import torch
     config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)]
     collect_count = 1000
     expert_data_path = config[0].policy.collect.save_path
-    state_dict = torch.load('./td3/ckpt/iteration_0.pth.tar', map_location='cpu')
+    state_dict = torch.load('./td3_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
     try:
         collect_demo_data(
             config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
diff --git a/ding/entry/tests/test_serial_entry_algo.py b/ding/entry/tests/test_serial_entry_algo.py
index e8f743b..82d527b 100644
--- a/ding/entry/tests/test_serial_entry_algo.py
+++ b/ding/entry/tests/test_serial_entry_algo.py
@@ -281,7 +281,9 @@ def test_acer():
 @pytest.mark.algotest
 def test_selfplay():
     try:
-        selfplay_main(deepcopy(league_demo_ppo_config), seed=0)
+        config = deepcopy(league_demo_ppo_config)
+        config.exp_name = 'test_selfplay'
+        selfplay_main(config, seed=0)
     except Exception:
         assert False, "pipeline fail"
     with open("./algo_record.log", "a+") as f:
@@ -291,7 +293,9 @@ def test_selfplay():
 @pytest.mark.algotest
 def test_league():
     try:
-        league_main(deepcopy(league_demo_ppo_config), seed=0)
+        config = deepcopy(league_demo_ppo_config)
+        config.exp_name = 'test_league'
+        league_main(config, seed=0)
     except Exception:
         assert False, "pipeline fail"
     with open("./algo_record.log", "a+") as f:
@@ -326,14 +330,13 @@ def test_cql():
         assert False, "pipeline fail"

     # collect expert data
-    import torch
     config = [
         deepcopy(pendulum_sac_data_genearation_default_config),
         deepcopy(pendulum_sac_data_genearation_default_create_config)
     ]
     collect_count = config[0].policy.other.replay_buffer.replay_buffer_size
     expert_data_path = config[0].policy.collect.save_path
-    state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu')
+    state_dict = torch.load('./sac_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
     try:
         collect_demo_data(
             config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
@@ -362,11 +365,10 @@ def test_discrete_cql():
         assert False, "pipeline fail"

     # collect expert data
-    import torch
     config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
     collect_count = config[0].policy.other.replay_buffer.replay_buffer_size
     expert_data_path = config[0].policy.collect.save_path
-    state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu')
+    state_dict = torch.load('./cql_cartpole_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
     try:
         collect_demo_data(
             config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
@@ -406,11 +408,10 @@ def test_td3_bc():
         assert False, "pipeline fail"

     # collect expert data
-    import torch
     config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)]
     collect_count = config[0].policy.other.replay_buffer.replay_buffer_size
     expert_data_path = config[0].policy.collect.save_path
-    state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu')
+    state_dict = torch.load('./td3_seed0/ckpt/iteration_0.pth.tar', map_location='cpu')
     try:
         collect_demo_data(
             config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py
index fe646a8..456114a 100644
--- a/ding/policy/dqn.py
+++ b/ding/policy/dqn.py
@@ -71,12 +71,17 @@ class DQNPolicy(Policy):

     config = dict(
         type='dqn',
+        # (bool) Whether to use cuda in policy
         cuda=False,
+        # (bool) Whether the learning policy is the same as the collecting policy (on-policy)
         on_policy=False,
+        # (bool) Whether to enable prioritized experience sampling
         priority=False,
         # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
         priority_IS_weight=False,
+        # (float) Discount factor (gamma) for returns
         discount_factor=0.97,
+        # (int) The number of steps for calculating the target q_value
         nstep=1,
         learn=dict(
             # (bool) Whether to use multi gpu
diff --git a/dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py b/dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py
index f366c85..6bc8a6c 100644
--- a/dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py
+++ b/dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py
@@ -3,6 +3,7 @@ from ding.entry import serial_pipeline
 from easydict import EasyDict

 pong_qrdqn_config = dict(
+    exp_name='pong_qrdqn_generation',
     env=dict(
         collector_env_num=8,
         evaluator_env_num=8,
@@ -39,7 +40,7 @@
         collect=dict(
             n_sample=100,
             data_type='hdf5',
-            save_path='./expert/expert.pkl',
+            save_path='expert.pkl',
         ),
         eval=dict(evaluator=dict(eval_freq=4000, )),
         other=dict(
diff --git a/dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py b/dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py
index e802e33..487470d 100644
--- a/dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py
+++ b/dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py
@@ -3,6 +3,7 @@ from ding.entry import serial_pipeline
 from easydict import EasyDict

 qbert_qrdqn_config = dict(
+    exp_name='qbert_qrdqn_generation',
     env=dict(
         collector_env_num=8,
         evaluator_env_num=8,
@@ -39,7 +40,7 @@
         collect=dict(
             n_sample=100,
             data_type='hdf5',
-            save_path='./expert/expert.pkl',
+            save_path='expert.pkl',
         ),
         eval=dict(evaluator=dict(eval_freq=4000, )),
         other=dict(
diff --git a/dizoo/classic_control/cartpole/config/cartpole_cql_config.py b/dizoo/classic_control/cartpole/config/cartpole_cql_config.py
index 8fa011e..e39405b 100644
--- a/dizoo/classic_control/cartpole/config/cartpole_cql_config.py
+++ b/dizoo/classic_control/cartpole/config/cartpole_cql_config.py
@@ -29,7 +29,7 @@ cartpole_discrete_cql_config = dict(
         ),
         collect=dict(
             data_type='hdf5',
-            data_path='./cartpole_generation/expert_demos.hdf5',
+            data_path='./cartpole_generation_seed0/expert_demos.hdf5',  # user-specific
             n_sample=80,
             unroll_len=1,
         ),
diff --git a/dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py b/dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py
index 0d6450b..2924693 100644
--- a/dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py
+++ b/dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py
@@ -37,7 +37,7 @@ cartpole_qrdqn_generation_data_config = dict(
             n_sample=80,
             unroll_len=1,
             data_type='hdf5',
-            save_path='./cartpole_generation/expert.pkl',
+            save_path='expert.pkl',
         ),
         other=dict(
             eps=dict(
diff --git a/dizoo/classic_control/pendulum/config/pendulum_cql_config.py b/dizoo/classic_control/pendulum/config/pendulum_cql_config.py
index 7e91497..d14378b 100644
--- a/dizoo/classic_control/pendulum/config/pendulum_cql_config.py
+++ b/dizoo/classic_control/pendulum/config/pendulum_cql_config.py
@@ -37,7 +37,7 @@ pendulum_cql_default_config = dict(
             n_sample=1,
             unroll_len=1,
             data_type='hdf5',
-            data_path='./sac/expert_demos.hdf5',
+            data_path='./pendulum_sac_generation_seed0/expert_demos.hdf5',  # user-specific
         ),
         command=dict(),
         eval=dict(evaluator=dict(eval_freq=100, )),
diff --git a/dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_default_config.py b/dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_default_config.py
index a673e8b..76e9e6e 100644
--- a/dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_default_config.py
+++ b/dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_default_config.py
@@ -1,6 +1,7 @@
 from easydict import EasyDict

 pendulum_sac_data_genearation_default_config = dict(
+    exp_name='pendulum_sac_generation',
     seed=0,
     env=dict(
         collector_env_num=10,
@@ -43,7 +44,7 @@ pendulum_sac_data_genearation_default_config = dict(
         collect=dict(
             n_sample=1,
             unroll_len=1,
-            save_path='./sac/expert.pkl',
+            save_path='expert.pkl',
             data_type='hdf5',
         ),
         command=dict(),
diff --git a/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py b/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py
index 200aa63..9e95815 100644
--- a/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py
+++ b/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py
@@ -44,7 +44,7 @@
             noise_sigma=0.1,
             collector=dict(collect_print_freq=1000, ),
             data_type='hdf5',
-            data_path='./td3/expert_demos.hdf5',
+            data_path='./pendulum_td3_generation_seed0/expert_demos.hdf5',  # user-specific
             normalize_states=True,
         ),
         eval=dict(evaluator=dict(eval_freq=100, ), ),
diff --git a/dizoo/classic_control/pendulum/config/pendulum_td3_data_generation_config.py b/dizoo/classic_control/pendulum/config/pendulum_td3_data_generation_config.py
index 3574434..f3d6694 100644
--- a/dizoo/classic_control/pendulum/config/pendulum_td3_data_generation_config.py
+++ b/dizoo/classic_control/pendulum/config/pendulum_td3_data_generation_config.py
@@ -1,7 +1,7 @@
 from easydict import EasyDict

 pendulum_td3_generation_config = dict(
-    exp_name='td3',
+    exp_name='pendulum_td3_generation',
     env=dict(
         collector_env_num=8,
         evaluator_env_num=10,
@@ -45,7 +45,7 @@ pendulum_td3_generation_config = dict(
             n_sample=10,
             noise_sigma=0.1,
             collector=dict(collect_print_freq=1000, ),
-            save_path='./td3/expert.pkl',
+            save_path='expert.pkl',
             data_type='hdf5',
         ),
         eval=dict(evaluator=dict(eval_freq=100, ), ),
diff --git a/dizoo/mujoco/config/hopper_sac_data_generation_default_config.py b/dizoo/mujoco/config/hopper_sac_data_generation_default_config.py
index 6a126d8..8aa8647 100644
--- a/dizoo/mujoco/config/hopper_sac_data_generation_default_config.py
+++ b/dizoo/mujoco/config/hopper_sac_data_generation_default_config.py
@@ -1,6 +1,7 @@
 from easydict import EasyDict

 hopper_sac_data_genearation_default_config = dict(
+    exp_name='hopper_sac_generation',
     env=dict(
         env_id='Hopper-v3',
         norm_obs=dict(use_norm=False, ),
@@ -45,7 +46,7 @@ hopper_sac_data_genearation_default_config = dict(
         collect=dict(
             n_sample=1,
             unroll_len=1,
-            save_path='./default_experiment/expert_iteration_200000.pkl',
+            save_path='expert_iteration_200000.pkl',
         ),
         command=dict(),
         eval=dict(),
diff --git a/dizoo/mujoco/config/hopper_td3_data_generation_config.py b/dizoo/mujoco/config/hopper_td3_data_generation_config.py
index 9a0f71c..31bd9e9 100644
--- a/dizoo/mujoco/config/hopper_td3_data_generation_config.py
+++ b/dizoo/mujoco/config/hopper_td3_data_generation_config.py
@@ -1,6 +1,7 @@
 from easydict import EasyDict

 halfcheetah_td3_default_config = dict(
+    exp_name='halfcheetah_td3_generation',
     env=dict(
         env_id='Hopper-v3',
         norm_obs=dict(use_norm=False, ),
@@ -49,7 +50,7 @@
             n_sample=1,
             unroll_len=1,
             noise_sigma=0.1,
-            save_path='./td3/expert.pkl',
+            save_path='expert.pkl',
             data_type='hdf5',
         ),
         other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
-- 
GitLab
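
Note (illustration, not part of the patch): the core of the fix is the os.path.join(cfg.exp_name, expert_data_path) line added to collect_demo_data and collect_episodic_demo_data, so the now-relative save_path values in the generation configs land inside the experiment directory. A minimal sketch of the resulting path, assuming (as the *_seed0 paths in the updated tests suggest) that the pipeline appends a _seed{seed} suffix to cfg.exp_name; the variable values below are hypothetical examples mirroring the configs in this patch:

    import os

    # Hypothetical values mirroring the configs above.
    exp_name = 'pendulum_sac_generation_seed0'  # cfg.exp_name after the assumed seed suffix is applied
    save_path = 'expert.pkl'                    # cfg.policy.collect.save_path (now relative)

    # Same join performed in collect_demo_data() in this patch.
    expert_data_path = os.path.join(exp_name, save_path)
    print(expert_data_path)  # -> pendulum_sac_generation_seed0/expert.pkl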