Commit 8dae34cb authored by niuyazhe

refactor(nyz): add basic new sil entry(ci skip)

Parent 2f079e02
@@ -6,7 +6,7 @@ from functools import partial
from tensorboardX import SummaryWriter
from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory, create_sil
@@ -58,7 +58,12 @@ def serial_pipeline_sil(
# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
base_learner = BaseLearner(
cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='base_learner'
)
sil_learner = BaseLearner(
cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='sil_learner'
)
collector = create_serial_collector(
cfg.policy.collect.collector,
env=collector_env,
@@ -71,64 +76,67 @@
)
replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
commander = BaseSerialCommander(
cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
cfg.policy.other.commander, base_learner, collector, evaluator, replay_buffer, policy.command_mode
)
# ==========
# Main loop
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
new_ptr = old_ptr = 0
base_learner.call_hook('before_run')
# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(policy.collect_mode)
if cfg.policy.get('transition_with_policy_data', False):
collector.reset_policy(policy.collect_mode)
else:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_episode=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
new_ptr += len(new_data)
for _ in range(max_iterations):
collect_kwargs = commander.step()
# Evaluate policy performance
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if evaluator.should_eval(base_learner.train_iter):
stop, reward = evaluator.eval(base_learner.save_checkpoint, base_learner.train_iter, collector.envstep)
if stop:
break
# Collect data by default config n_sample/n_episode
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
new_data = collector.collect(train_iter=base_learner.train_iter, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
new_ptr += len(new_data)
# Learn policy from collected data
for i in range(cfg.policy.learn.update_per_collect):
# Learner will train ``update_per_collect`` times in one iteration.
if cfg.policy.on_policy:
train_data_base_policy = replay_buffer.sample(
learner.policy.get_attribute('batch_size'),
learner.train_iter,
sample_range=slice(old_ptr - new_ptr, None)
)
else:
if cfg.policy.on_policy:
train_data_base_policy = replay_buffer.sample(
base_learner.policy.get_attribute('batch_size'),
base_learner.train_iter,
sample_range=slice(-len(new_data), None)
)
base_learner.train(train_data_base_policy)
else:
for i in range(cfg.policy.learn.update_per_collect):
# Learner will train ``update_per_collect`` times in one iteration.
train_data_base_policy = replay_buffer.sample(
learner.policy.get_attribute('batch_size'), learner.train_iter
base_learner.policy.get_attribute('batch_size'), base_learner.train_iter
)
train_data_sil = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data_base_policy is None or train_data_sil is None:
base_learner.train(train_data_base_policy)
if base_learner.policy.get_attribute('priority'):
replay_buffer.update(base_learner.priority_info)
for i in range(cfg.policy.other.sil.update_per_collect):
train_data_sil = replay_buffer.sample(
cfg.policy.other.sil.n_episode_per_train, sil_learner.train_iter, groupby='episode'
)
train_data_sil = policy.process_sil_data(train_data_sil)
if train_data_sil is None:
# It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
logging.warning(
"Replay buffer's data can only train for {} steps. ".format(i) +
"Replay buffer's data can only train for sil {} steps. ".format(i) +
"You can modify data collect config, e.g. increasing n_sample, n_episode."
)
break
learner.train({'base_policy': train_data_base_policy, 'sil': train_data_sil}, collector.envstep)
if learner.policy.get_attribute('priority'):
replay_buffer.update(learner.priority_info)
if cfg.policy.on_policy:
# On-policy algorithm must clear the replay buffer.
# replay_buffer.clear()
old_ptr = new_ptr
sil_learner.train(train_data_sil)
# Learner's after_run hook.
learner.call_hook('after_run')
base_learner.call_hook('after_run')
return policy
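
For context, a minimal usage sketch of the new entry (not part of this commit): it mirrors the __main__ block of the A2C+SIL config added at the end of this commit. The config module path below is an assumption based on where the existing LunarLander configs live.

# Usage sketch only, not part of the diff; the import path of the config module is assumed.
from ding.entry.serial_entry_sil import serial_pipeline_sil
from dizoo.box2d.lunarlander.config.lunarlander_a2c_sil_config import main_config, create_config

if __name__ == '__main__':
    # Runs the collect -> base_learner (A2C) -> sil_learner loop until stop_value or max_iterations.
    serial_pipeline_sil((main_config, create_config), seed=0)
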
@@ -320,7 +320,7 @@ class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
def forward(self, *args, **kwargs):
eps = kwargs.pop('eps')
alpha = kwargs.pop('alpha')
alpha = kwargs.pop('alpha', 1)
output = self._model.forward(*args, **kwargs)
assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
logit = output['logit']
......
@@ -170,7 +170,7 @@ class A2CPolicy(Policy):
self._gamma = self._cfg.collect.discount_factor
self._gae_lambda = self._cfg.collect.gae_lambda
def _forward_collect(self, data: dict) -> dict:
def _forward_collect(self, data: dict, **kwargs) -> dict:
r"""
Overview:
Forward function of collect mode.
@@ -188,7 +188,7 @@ class A2CPolicy(Policy):
data = to_device(data, self._device)
self._collect_model.eval()
with torch.no_grad():
output = self._collect_model.forward(data, mode='compute_actor_critic')
output = self._collect_model.forward(data, mode='compute_actor_critic', **kwargs)
if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
......
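
Taken together, the two small changes above are complementary: EpsGreedyMultinomialSampleWrapper now falls back to alpha = 1 when the caller does not pass it, and A2CPolicy._forward_collect forwards its extra keyword arguments to the wrapped collect model. A self-contained illustration of the intended behaviour (not DI-engine code):

# Illustration only: mimics how the wrapper reads optional sampling kwargs after this change.
def wrapper_forward(**kwargs):
    eps = kwargs.pop('eps')          # still required by the wrapper
    alpha = kwargs.pop('alpha', 1)   # now optional, defaults to 1
    return eps, alpha

# A collect-mode forward can simply pass through whatever kwargs it received:
print(wrapper_forward(eps=0.05))             # (0.05, 1)   alpha falls back to the default
print(wrapper_forward(eps=0.05, alpha=0.3))  # (0.05, 0.3) an explicit alpha still wins
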
@@ -17,12 +17,14 @@ def create_sil(policy: Policy, cfg):
return sil_policy
class SIL(Policy):
class SILPolicy(Policy):
r"""
Overview:
Policy class of SIL algorithm.
"""
sil_config = dict(
config = dict(
update_per_collect=10,
n_episode_per_train=4,
value_weight=0.5,
learning_rate=0.001,
betas=(0.9, 0.999),
@@ -32,8 +34,8 @@ class SIL(Policy):
def __init__(self, policy: Policy, cfg):
self.base_policy = policy
self._model = policy._model
cfg.policy.other.sil = deep_merge_dicts(self.sil_config, cfg.policy.other.sil)
super(SIL, self).__init__(cfg.policy, model=policy._model, enable_field=policy._enable_field)
cfg.policy.other.sil = deep_merge_dicts(self.config, cfg.policy.other.sil)
super(SILPolicy, self).__init__(cfg.policy, model=policy._model, enable_field=policy._enable_field)
def _init_learn(self) -> None:
r"""
@@ -184,5 +186,5 @@ class SIL(Policy):
]
class SILCommand(SIL, DummyCommandModePolicy):
class SILCommand(SILPolicy, DummyCommandModePolicy):
pass
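
The composition implied by create_sil and SILPolicy.__init__ can be summarised with a short sketch. The actual call site is elided from this diff, so the lines below are an assumption about how the new entry wires things together, not a copy of it:

# Assumed wiring (the real call site is not shown in this diff).
from ding.policy import create_policy, create_sil

def build_policies(cfg):
    # cfg is the compiled EasyDict config (e.g. the output of ding.config.compile_config).
    base_policy = create_policy(cfg.policy, model=None, enable_field=['learn', 'collect', 'eval', 'command'])
    # SILPolicy shares base_policy._model and merges cfg.policy.other.sil with its defaults
    # (update_per_collect=10, n_episode_per_train=4, value_weight=0.5, ...).
    sil_policy = create_sil(base_policy, cfg)
    return base_policy, sil_policy

In the entry above, base_learner handles the regular policy updates on freshly collected samples, while the SIL side re-trains on episode batches drawn from the replay buffer and preprocessed by policy.process_sil_data(...).
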
from ding.entry.serial_entry_onpolicy import serial_pipeline_onpolicy
from easydict import EasyDict
lunarlander_a2c_config = dict(
exp_name='lunarlander_a2c_seed0',
env=dict(
collector_env_num=4,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=200,
),
policy=dict(
cuda=False,
model=dict(
obs_shape=8,
action_shape=4,
encoder_hidden_size_list=[128, 64],
share_encoder=False,
),
learn=dict(
batch_size=64,
# (bool) Whether to normalize advantage. Default to False.
adv_norm=False,
learning_rate=0.001,
# (float) loss weight of the value network, the weight of policy network is set to 1
value_weight=0.1,
# (float) loss weight of the entropy regularization, the weight of policy network is set to 1
entropy_weight=0.00001,
),
collect=dict(
# (int) collect n_sample data, train model n_iteration times
n_sample=64,
# (float) the trade-off factor lambda to balance 1step td and mc
gae_lambda=0.95,
discount_factor=0.995,
),
),
)
lunarlander_a2c_config = EasyDict(lunarlander_a2c_config)
main_config = lunarlander_a2c_config
lunarlander_a2c_create_config = dict(
env=dict(
type='lunarlander',
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='a2c'),
)
lunarlander_a2c_create_config = EasyDict(lunarlander_a2c_create_config)
create_config = lunarlander_a2c_create_config
if __name__ == '__main__':
serial_pipeline_onpolicy((main_config, create_config), seed=0)
@@ -2,47 +2,45 @@ from ding.entry.serial_entry_sil import serial_pipeline_sil
from easydict import EasyDict
lunarlander_a2c_config = dict(
exp_name='lunarlander_a2c',
exp_name='lunarlander_a2c_sil_seed0',
env=dict(
collector_env_num=8,
collector_env_num=4,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=195,
stop_value=200,
),
policy=dict(
on_policy=True,
cuda=False,
# (bool) whether use on-policy training pipeline(behaviour policy and training policy are the same)
model=dict(
obs_shape=8,
action_shape=4,
encoder_hidden_size_list=[512, 64],
encoder_hidden_size_list=[128, 64],
share_encoder=False,
),
learn=dict(
batch_size=64,
# (bool) Whether to normalize advantage. Default to False.
unroll_len=1,
normalize_advantage=False,
adv_norm=False,
learning_rate=0.001,
# (float) loss weight of the value network, the weight of policy network is set to 1
value_weight=0.5,
value_weight=0.1,
# (float) loss weight of the entropy regularization, the weight of policy network is set to 1
entropy_weight=1.0,
entropy_weight=0.00001,
),
collect=dict(
collector=dict(
type='episode',
get_train_sample=True,
),
# (int) collect n_sample data, train model n_iteration times
n_episode=8,
n_sample=64,
# (float) the trade-off factor lambda to balance 1step td and mc
gae_lambda=0.95,
discount_factor=0.995,
),
other=dict(sil=dict(
value_weight=0.5,
learning_rate=0.0001,
), replay_buffer=dict(replay_buffer_size=200000, )),
other=dict(
replay_buffer=dict(replay_buffer_size=100000, ),
sil=dict(
update_per_collect=10,
n_episode_per_train=8,
),
)
),
)
lunarlander_a2c_config = EasyDict(lunarlander_a2c_config)
@@ -55,11 +53,10 @@ lunarlander_a2c_create_config = dict(
),
env_manager=dict(type='subprocess'),
policy=dict(type='a2c'),
replay_buffer=dict(type='deque'),
)
lunarlander_a2c_create_config = EasyDict(lunarlander_a2c_create_config)
create_config = lunarlander_a2c_create_config
if __name__ == '__main__':
from ding.entry import serial_entry_sil
serial_pipeline_sil((main_config, create_config), seed=0)