Commit 4e833da2 authored by niuyazhe

polish(nyz): polish cartpole ppo demo and related unittest

Parent 3243c92d
......@@ -13,7 +13,7 @@ from easydict import EasyDict
from ding.utils import deep_merge_dicts
from ding.envs import get_env_cls, get_env_manager_cls
from ding.policy import get_policy_cls
from ding.worker import BaseLearner, BaseSerialEvaluator, BaseSerialCommander, Coordinator, \
from ding.worker import BaseLearner, BaseSerialEvaluator, BaseSerialCommander, Coordinator, AdvancedReplayBuffer, \
get_parallel_commander_cls, get_parallel_collector_cls, get_buffer_cls, get_serial_collector_cls
from ding.reward_model import get_reward_model_cls
from .utils import parallel_transform, parallel_transform_slurm, parallel_transform_k8s, save_config_formatted
......@@ -309,7 +309,7 @@ def compile_config(
learner: type = BaseLearner,
collector: type = None,
evaluator: type = BaseSerialEvaluator,
buffer: type = None,
buffer: type = AdvancedReplayBuffer,
env: type = None,
reward_model: type = None,
seed: int = 0,
......
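Note on the `compile_config` default above: with `buffer` now defaulting to `AdvancedReplayBuffer`, callers that rely on the automatic mode no longer need to pass a buffer class explicitly. A minimal sketch of the usage exercised by the updated unittests (same call as in `test_eval` below):

```python
# Sketch based on the test usage in this commit: compile_config resolves the
# replay-buffer section against the new AdvancedReplayBuffer default when the
# caller does not pass a buffer class.
from ding.config import compile_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import (
    cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config
)

cfg = compile_config(
    cartpole_ppo_offpolicy_config, auto=True, create_cfg=cartpole_ppo_offpolicy_create_config
)
print(cfg.env.stop_value)  # stop value used by the evaluator (195 for this config)
```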
......@@ -3,7 +3,7 @@ import pytest
import os
import pickle
from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config # noqa
from dizoo.classic_control.cartpole.envs import CartPoleEnv
from ding.entry import serial_pipeline, eval, collect_demo_data
from ding.config import compile_config
......@@ -11,7 +11,7 @@ from ding.config import compile_config
@pytest.fixture(scope='module')
def setup_state_dict():
config = deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)
config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
try:
policy = serial_pipeline(config, seed=0)
except Exception:
......@@ -27,12 +27,14 @@ def setup_state_dict():
class TestApplication:
def test_eval(self, setup_state_dict):
cfg_for_stop_value = compile_config(cartpole_ppo_config, auto=True, create_cfg=cartpole_ppo_create_config)
cfg_for_stop_value = compile_config(
cartpole_ppo_offpolicy_config, auto=True, create_cfg=cartpole_ppo_offpolicy_create_config
)
stop_value = cfg_for_stop_value.env.stop_value
config = deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)
config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
eval_reward = eval(config, seed=0, state_dict=setup_state_dict['eval'])
assert eval_reward >= stop_value
config = deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)
config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
eval_reward = eval(
config,
seed=0,
......@@ -42,7 +44,7 @@ class TestApplication:
assert eval_reward >= stop_value
def test_collect_demo_data(self, setup_state_dict):
config = deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)
config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
collect_count = 16
expert_data_path = './expert.data'
collect_demo_data(
......
......@@ -15,6 +15,7 @@ from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import cartpole
from dizoo.classic_control.cartpole.config.cartpole_sqn_config import cartpole_sqn_config, cartpole_sqn_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_ppg_config import cartpole_ppg_config, cartpole_ppg_create_config # noqa
from dizoo.classic_control.cartpole.entry.cartpole_ppg_main import main as ppg_main
from dizoo.classic_control.cartpole.entry.cartpole_ppo_main import main as ppo_main
from dizoo.classic_control.cartpole.config.cartpole_r2d2_config import cartpole_r2d2_config, cartpole_r2d2_create_config # noqa
from dizoo.classic_control.pendulum.config import pendulum_ddpg_config, pendulum_ddpg_create_config
from dizoo.classic_control.pendulum.config import pendulum_td3_config, pendulum_td3_create_config
......@@ -116,7 +117,7 @@ def test_ppo():
config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
config[0].policy.learn.update_per_collect = 1
try:
serial_pipeline(config, seed=0, max_iterations=1)
ppo_main(config[0], seed=0, max_iterations=1)
except Exception:
assert False, "pipeline fail"
......
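A quick note on the argument change in `test_ppo`: `ppo_main` takes only the main config dict (`config[0]`) and wires the concrete env-manager/policy/worker classes itself inside `compile_config` (see the updated `cartpole_ppo_main` later in this diff), whereas `serial_pipeline` expects the `(main_config, create_config)` pair. A hedged sketch of the two call styles:

```python
# Both call styles appear in this commit's tests; the commented line is the
# previous generic-pipeline behaviour, kept only for comparison.
from copy import deepcopy
from dizoo.classic_control.cartpole.config.cartpole_ppo_config import (
    cartpole_ppo_config, cartpole_ppo_create_config
)
from dizoo.classic_control.cartpole.entry.cartpole_ppo_main import main as ppo_main

config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
ppo_main(config[0], seed=0, max_iterations=1)          # dedicated demo entry: main config only
# serial_pipeline(config, seed=0, max_iterations=1)    # generic entry: (main, create) pair
```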
......@@ -15,6 +15,7 @@ from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import cartpole
from dizoo.classic_control.cartpole.config.cartpole_sqn_config import cartpole_sqn_config, cartpole_sqn_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_ppg_config import cartpole_ppg_config, cartpole_ppg_create_config # noqa
from dizoo.classic_control.cartpole.entry.cartpole_ppg_main import main as ppg_main
from dizoo.classic_control.cartpole.entry.cartpole_ppo_main import main as ppo_main
from dizoo.classic_control.cartpole.config.cartpole_r2d2_config import cartpole_r2d2_config, cartpole_r2d2_create_config # noqa
from dizoo.classic_control.pendulum.config import pendulum_ddpg_config, pendulum_ddpg_create_config
from dizoo.classic_control.pendulum.config import pendulum_td3_config, pendulum_td3_create_config
......@@ -90,7 +91,7 @@ def test_rainbow():
def test_ppo():
config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
try:
serial_pipeline(config, seed=0)
ppo_main(config[0], seed=0)
except Exception:
assert False, "pipeline fail"
with open("./algo_record.log", "a+") as f:
......@@ -260,4 +261,4 @@ def test_qrdqn():
except Exception:
assert False, "pipeline fail"
with open("./algo_record.log", "a+") as f:
f.write("20. sqn\n")
f.write("21. qrdqn\n")
......@@ -6,21 +6,21 @@ import torch
from collections import namedtuple
import os
from dizoo.classic_control.cartpole.config import cartpole_ppo_config, cartpole_ppo_create_config, \
cartpole_dqn_config, cartpole_dqn_create_config
from ding.torch_utils import Adam, to_device
from ding.config import compile_config
from ding.model import model_wrap
from ding.rl_utils import get_train_sample, get_nstep_return_data
from ding.entry import serial_pipeline_il, collect_demo_data, serial_pipeline
from ding.policy import PPOPolicy, ILPolicy
from ding.policy import PPOOffPolicy, ILPolicy
from ding.policy.common_utils import default_preprocess_learn
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from dizoo.classic_control.cartpole.config import cartpole_dqn_config, cartpole_dqn_create_config, \
cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config
@POLICY_REGISTRY.register('ppo_il')
class PPOILPolicy(PPOPolicy):
class PPOILPolicy(PPOOffPolicy):
def _forward_learn(self, data: dict) -> dict:
data = default_preprocess_learn(data, ignore_done=self._cfg.learn.get('ignore_done', False), use_nstep=False)
......@@ -46,20 +46,20 @@ class PPOILPolicy(PPOPolicy):
@pytest.mark.unittest
def test_serial_pipeline_il_ppo():
# train expert policy
train_config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
train_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
expert_policy = serial_pipeline(train_config, seed=0)
# collect expert demo data
collect_count = 10000
expert_data_path = 'expert_data_ppo.pkl'
state_dict = expert_policy.collect_mode.state_dict()
collect_config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
collect_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
collect_demo_data(
collect_config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
)
# il training 1
il_config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
il_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
il_config[0].policy.learn.train_epoch = 10
il_config[0].policy.type = 'ppo_il'
_, converge_stop_flag = serial_pipeline_il(il_config, seed=314, data_path=expert_data_path)
......
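The `'ppo_il'` string assigned to `il_config[0].policy.type` works because the `POLICY_REGISTRY.register` decorator above stores `PPOILPolicy` under that key, so the imitation-learning pipeline can look the class up by name from the config. A minimal sketch of the same registration pattern (the class name and key here are hypothetical, for illustration only):

```python
from ding.utils import POLICY_REGISTRY
from ding.policy import PPOOffPolicy


@POLICY_REGISTRY.register('my_ppo_variant')   # hypothetical key, mirrors 'ppo_il' above
class MyPPOVariant(PPOOffPolicy):
    """Toy subclass: inherits everything from PPOOffPolicy unchanged."""
    pass

# Selecting it from a config is then just a string lookup:
#   some_config.policy.type = 'my_ppo_variant'
```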
......@@ -5,7 +5,7 @@ from easydict import EasyDict
from copy import deepcopy
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_ppo_rnd_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config # noqa
from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model
......@@ -42,13 +42,13 @@ cfg = [
@pytest.mark.parametrize('reward_model_config', cfg)
def test_irl(reward_model_config):
reward_model_config = EasyDict(reward_model_config)
config = deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)
config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
expert_policy = serial_pipeline(config, seed=0, max_iterations=2)
# collect expert demo data
collect_count = 10000
expert_data_path = 'expert_data.pkl'
state_dict = expert_policy.collect_mode.state_dict()
config = deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)
config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
collect_demo_data(
config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
)
......
......@@ -111,7 +111,7 @@ class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
@POLICY_REGISTRY.register('ppo_offpolicy_command')
class PPOCommandModePolicy(PPOOffPolicy, DummyCommandModePolicy):
class PPOOffCommandModePolicy(PPOOffPolicy, DummyCommandModePolicy):
pass
......
......@@ -220,7 +220,7 @@ class PPOPolicy(Policy):
'value_max': output['value'].max().item(),
'approx_kl': ppo_info.approx_kl,
'clipfrac': ppo_info.clipfrac,
'act': batch['action'].mean().item(),
'act': batch['action'].float().mean().item(),
}
if self._continuous:
return_info.update(
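Why the `.float()` cast in the logging dict above: for discrete-action environments such as CartPole, `batch['action']` is an integer tensor, and `torch.mean` is only defined for floating-point (or complex) dtypes, so logging the raw tensor's mean raises a dtype error. A tiny reproduction:

```python
import torch

action = torch.tensor([0, 1, 1, 0])          # LongTensor, as stored for discrete actions
# action.mean()                              # RuntimeError: mean only supports floating dtypes
mean_act = action.float().mean().item()      # 0.5 -- safe value for the 'act' log entry
```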
......@@ -326,7 +326,7 @@ class PPOPolicy(Policy):
else:
with torch.no_grad():
last_value = self._collect_model.forward(
data[-1]['next_obs'].unsqueeze(0), mode='compute_critic'
data[-1]['next_obs'].unsqueeze(0), mode='compute_actor_critic'
)['value']
if self._value_norm:
last_value *= self._running_mean_std.std
......@@ -458,9 +458,7 @@ class PPOOffPolicy(Policy):
gae_lambda=0.95,
),
eval=dict(),
# Although ppo is an on-policy algorithm, ding reuses the buffer mechanism, and clear buffer after update.
# Note replay_buffer_size must be greater than n_sample.
other=dict(replay_buffer=dict(replay_buffer_size=1000, ), ),
other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
)
def _init_learn(self) -> None:
......
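Since `PPOOffPolicy` now treats the buffer as a genuine replay buffer (the old clear-after-update note is gone), the default `replay_buffer_size` is raised to 10000. Individual experiments still override this default through the usual config merge; a rough sketch, assuming the merge direction is user config over policy defaults:

```python
# Rough sketch of how a user value overrides the PPOOffPolicy default via
# deep merge (assumed direction: user values win over defaults).
from easydict import EasyDict
from ding.utils import deep_merge_dicts

default_other = EasyDict(dict(replay_buffer=dict(replay_buffer_size=10000)))
user_other = EasyDict(dict(replay_buffer=dict(replay_buffer_size=5000)))   # as in cartpole_ppo_offpolicy_config below
merged = deep_merge_dicts(default_other, user_other)
assert merged['replay_buffer']['replay_buffer_size'] == 5000
```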
......@@ -2,6 +2,7 @@ from .cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from .cartpole_a2c_config import cartpole_a2c_config, cartpole_a2c_create_config
from .cartpole_impala_config import cartpole_impala_config, cartpole_impala_create_config
from .cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
from .cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config
from .cartpole_rainbow_config import cartpole_rainbow_config, cartpole_rainbow_create_config
from .cartpole_iqn_config import cartpole_iqn_config, cartpole_iqn_create_config
from .cartpole_c51_config import cartpole_c51_config, cartpole_c51_create_config
......
......@@ -10,6 +10,8 @@ cartpole_ppo_config = dict(
),
policy=dict(
cuda=False,
on_policy=True,
continuous=False,
model=dict(
obs_shape=4,
action_shape=2,
......@@ -18,7 +20,7 @@ cartpole_ppo_config = dict(
actor_head_hidden_size=128,
),
learn=dict(
update_per_collect=6,
epoch_per_collect=2,
batch_size=64,
learning_rate=0.001,
value_weight=0.5,
......@@ -26,7 +28,7 @@ cartpole_ppo_config = dict(
clip_ratio=0.2,
),
collect=dict(
n_sample=128,
n_sample=256,
unroll_len=1,
discount_factor=0.9,
gae_lambda=0.95,
......
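With the switch from `update_per_collect` to `epoch_per_collect`, each freshly collected on-policy batch is reused for a fixed number of full passes rather than a fixed number of sampled minibatches. Under the values above, and assuming the usual PPO schedule of splitting each epoch into `batch_size` minibatches, the per-collection arithmetic is:

```python
# Back-of-the-envelope gradient-step count per collection for the new settings
# (assumes epoch_per_collect full passes, each split into batch_size minibatches).
n_sample, batch_size, epoch_per_collect = 256, 64, 2

minibatches_per_epoch = n_sample // batch_size                       # 4
grad_steps_per_collect = minibatches_per_epoch * epoch_per_collect   # 8
print(grad_steps_per_collect)
```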
from easydict import EasyDict
cartpole_ppo_offpolicy_config = dict(
exp_name='cartpole_ppo_offpolicy',
env=dict(
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=195,
),
policy=dict(
on_policy=False,
cuda=False,
model=dict(
obs_shape=4,
action_shape=2,
encoder_hidden_size_list=[64, 64, 128],
critic_head_hidden_size=128,
actor_head_hidden_size=128,
),
learn=dict(
update_per_collect=6,
batch_size=64,
learning_rate=0.001,
value_weight=0.5,
entropy_weight=0.01,
clip_ratio=0.2,
),
collect=dict(
n_sample=128,
unroll_len=1,
discount_factor=0.9,
gae_lambda=0.95,
),
other=dict(replay_buffer=dict(replay_buffer_size=5000))
),
)
cartpole_ppo_offpolicy_config = EasyDict(cartpole_ppo_offpolicy_config)
main_config = cartpole_ppo_offpolicy_config
cartpole_ppo_offpolicy_create_config = dict(
env=dict(
type='cartpole',
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
env_manager=dict(type='base'),
policy=dict(type='ppo_offpolicy'),
)
cartpole_ppo_offpolicy_create_config = EasyDict(cartpole_ppo_offpolicy_create_config)
create_config = cartpole_ppo_offpolicy_create_config
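How this new config pair is consumed (mirroring the updated unittests earlier in this diff): the `(main_config, create_config)` tuple goes straight into `serial_pipeline`, which instantiates the env, the `ppo_offpolicy` policy and the replay buffer from it.

```python
# Usage sketch taken from the unittests in this commit.
from copy import deepcopy
from ding.entry import serial_pipeline
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import (
    cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config
)

if __name__ == "__main__":
    config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
    serial_pipeline(config, seed=0)
```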
......@@ -17,6 +17,7 @@ cartpole_ppo_rnd_config = dict(
),
policy=dict(
cuda=False,
on_policy=False,
model=dict(
obs_shape=4,
action_shape=2,
......@@ -48,7 +49,7 @@ cartpole_ppo_rnd_create_config = dict(
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
env_manager=dict(type='base'),
policy=dict(type='ppo'),
policy=dict(type='ppo_offpolicy'),
reward_model=dict(type='rnd'),
)
cartpole_ppo_rnd_create_config = EasyDict(cartpole_ppo_rnd_create_config)
......
......@@ -3,7 +3,7 @@ import gym
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, NaiveReplayBuffer
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
......@@ -17,14 +17,7 @@ def wrapped_cartpole_env():
def main(cfg, seed=0, max_iterations=int(1e10)):
cfg = compile_config(
cfg,
BaseEnvManager,
PPOPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator,
NaiveReplayBuffer,
save_cfg=True
cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleCollector, BaseSerialEvaluator, save_cfg=True
)
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
......@@ -44,7 +37,6 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
evaluator = BaseSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = NaiveReplayBuffer(cfg.policy.other.replay_buffer, exp_name=cfg.exp_name)
for _ in range(max_iterations):
if evaluator.should_eval(learner.train_iter):
......@@ -52,13 +44,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
if stop:
break
new_data = collector.collect(train_iter=learner.train_iter)
assert all([len(c) == 0 for c in collector._traj_buffer.values()])
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
for i in range(cfg.policy.learn.update_per_collect):
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is not None:
learner.train(train_data, collector.envstep)
replay_buffer.clear()
learner.train(new_data, collector.envstep)
if __name__ == "__main__":
......
import os
import gym
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOOffPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed, deep_merge_dicts
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config
def wrapped_cartpole_env():
return DingEnvWrapper(gym.make('CartPole-v0'))
def main(cfg, seed=0, max_iterations=int(1e10)):
cfg = compile_config(
cfg,
BaseEnvManager,
PPOOffPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator,
NaiveReplayBuffer,
save_cfg=True
)
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
collector_env.seed(seed)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
model = VAC(**cfg.policy.model)
policy = PPOOffPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = BaseSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = NaiveReplayBuffer(cfg.policy.other.replay_buffer, exp_name=cfg.exp_name)
for _ in range(max_iterations):
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
new_data = collector.collect(train_iter=learner.train_iter)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
for i in range(cfg.policy.learn.update_per_collect):
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is not None:
learner.train(train_data, collector.envstep)
if __name__ == "__main__":
main(cartpole_ppo_offpolicy_config)
......@@ -5,7 +5,7 @@ from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.policy import PPOOffPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed, deep_merge_dicts
from ding.reward_model import RndRewardModel
......@@ -20,7 +20,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
cfg = compile_config(
cfg,
BaseEnvManager,
PPOPolicy,
PPOOffPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator,
......@@ -37,7 +37,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
model = VAC(**cfg.policy.model)
policy = PPOPolicy(cfg.policy, model=model)
policy = PPOOffPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleCollector(
......@@ -55,7 +55,6 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
if stop:
break
new_data = collector.collect(train_iter=learner.train_iter)
assert all([len(c) == 0 for c in collector._traj_buffer.values()])
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
reward_model.collect_data(new_data)
reward_model.train()
......@@ -65,7 +64,6 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
reward_model.estimate(train_data)
if train_data is not None:
learner.train(train_data, collector.envstep)
replay_buffer.clear()
if __name__ == "__main__":
......