Unverified commit 9929dc37, authored by Will-Nie, committed by GitHub

SQIL (#25)

* add sqil

* conceal all the personal info

* revise according to the comments

* correct_format

* add_comment to hardcodes part

* pass flake8

* add force_reproducibility = True; device, ex_model

* check format
Parent 9bc39314
......@@ -52,7 +52,7 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
@click.option(
'-m',
'--mode',
type=click.Choice(['serial', 'parallel', 'dist', 'eval']),
type=click.Choice(['serial', 'serial_sqil', 'parallel', 'dist', 'eval']),
help='serial-train or parallel-train or dist-train or eval'
)
@click.option('-c', '--config', type=str, help='Path to DRL experiment config')
......@@ -119,6 +119,13 @@ def cli(
if config is None:
config = get_predefined_config(env, policy)
serial_pipeline(config, seed, max_iterations=train_iter)
    elif mode == 'serial_sqil':
        # SQIL is currently supported for: cartpole, lunarlander, pong, spaceinvaders, qbert.
        supported_sqil_configs = (
            'lunarlander_sqil_config.py', 'cartpole_sqil_config.py', 'pong_sqil_config.py',
            'spaceinvaders_sqil_config.py', 'qbert_sqil_config.py'
        )
        if config is None or config.endswith(supported_sqil_configs):
            from .serial_entry_sqil import serial_pipeline_sqil
            if config is None:
                config = get_predefined_config(env, policy)
            serial_pipeline_sqil(config, seed, max_iterations=train_iter)
elif mode == 'parallel':
from .parallel_entry import parallel_pipeline
parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
......
from ding.policy.base_policy import Policy
from typing import Union, Optional, List, Any, Tuple
import os
import torch
import logging
from functools import partial
from tensorboardX import SummaryWriter
from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.utils import set_pkg_seed
from ding.model import DQN
def serial_pipeline_sqil(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
expert_model: Optional[torch.nn.Module] = None,
max_iterations: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
Serial pipeline sqil entry: we create this serial pipeline in order to\
implement SQIL in DI-engine. For now, we support the following envs\
Cartpole, Lunarlander, Pong, Spaceinvader, Qbert. The demonstration\
data come from the expert model. We use a well-trained model to \
generate demonstration data online
Arguments:
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
``str`` type means config file path. \
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- expert_model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.\
The default model is DQN(**cfg.policy.model)
- max_iterations (:obj:`Optional[torch.nn.Module]`): Learner's max iteration. Pipeline will stop \
when reaching this iteration.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
if isinstance(input_cfg, str):
cfg, create_cfg = read_config(input_cfg)
else:
cfg, create_cfg = input_cfg
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
else:
env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
expert_collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
expert_collector_env.seed(cfg.seed)
collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
#expert_model = DQN(**cfg.policy.model)
expert_policy = create_policy(cfg.policy, model=expert_model, enable_field=['collect'])
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
#model = DQN(**cfg.policy.model)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
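    # Load the well-trained expert checkpoint (``cfg.policy.collect.demonstration_info_path``) into the
    # expert policy's collect mode; the expert generates demonstration data online during training.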
expert_policy.collect_mode.load_state_dict(
torch.load(cfg.policy.collect.demonstration_info_path, map_location='cpu')
)
# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = create_serial_collector(
cfg.policy.collect.collector,
env=collector_env,
policy=policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name
)
expert_collector = create_serial_collector(
cfg.policy.collect.collector,
env=expert_collector_env,
policy=expert_policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name
)
evaluator = BaseSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
expert_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
commander = BaseSerialCommander(
cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
)
# ==========
# Main loop
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
for _ in range(max_iterations):
collect_kwargs = commander.step()
# Evaluate policy performance
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Collect data by default config n_sample/n_episode
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
expert_data = expert_collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': -1})
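        # SQIL reward relabeling: agent transitions get constant reward 0 and expert transitions get constant
        # reward 1, so the Q-learner is rewarded only for matching the demonstrator's behaviour.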
for i in range(len(new_data)):
device_1 = new_data[i]['obs'].device
device_2 = expert_data[i]['obs'].device
new_data[i]['reward'] = torch.zeros(cfg.policy.nstep).to(device_1)
expert_data[i]['reward'] = torch.ones(cfg.policy.nstep).to(device_2)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
expert_buffer.push(expert_data, cur_collector_envstep=collector.envstep)
# Learn policy from collected data
for i in range(cfg.policy.learn.update_per_collect):
# Learner will train ``update_per_collect`` times in one iteration.
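            # Each training batch is drawn half from the agent replay buffer and half from the expert buffer.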
train_data = replay_buffer.sample((learner.policy.get_attribute('batch_size')) // 2, learner.train_iter)
train_data_demonstration = expert_buffer.sample(
(learner.policy.get_attribute('batch_size')) // 2, learner.train_iter
)
            if train_data is None or train_data_demonstration is None:
# It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
logging.warning(
"Replay buffer's data can only train for {} steps. ".format(i) +
"You can modify data collect config, e.g. increasing n_sample, n_episode."
)
break
train_data = train_data + train_data_demonstration
learner.train(train_data, collector.envstep)
if learner.policy.get_attribute('priority'):
replay_buffer.update(learner.priority_info)
if cfg.policy.on_policy:
# On-policy algorithm must clear the replay buffer.
replay_buffer.clear()
# Learner's after_run hook.
learner.call_hook('after_run')
return policy
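For reference, a minimal usage sketch of the new entry (the config import path and the checkpoint path below are illustrative placeholders, not part of this commit):

from ding.entry.serial_entry_sqil import serial_pipeline_sqil
# Assumed location of the cartpole SQIL config added in this commit.
from dizoo.classic_control.cartpole.config.cartpole_sqil_config import main_config, create_config

# Point the pipeline at a well-trained expert checkpoint (placeholder path).
main_config.policy.collect.demonstration_info_path = './cartpole_sql/ckpt/ckpt_best.pth.tar'
serial_pipeline_sqil((main_config, create_config), seed=0)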
......@@ -5,6 +5,7 @@ import numpy as np
import torch
from ding.torch_utils import get_tensor_data
from ding.rl_utils import create_noise_generator
from torch.distributions import Categorical
class IModelWrapper(ABC):
......@@ -241,6 +242,50 @@ class EpsGreedySampleWrapper(IModelWrapper):
return output
class EpsGreedySampleWrapperSql(IModelWrapper):
r"""
Overview:
Epsilon greedy sampler used in collector_model to help balance exploratin and exploitation.
Interfaces:
register
"""
def forward(self, *args, **kwargs):
eps = kwargs.pop('eps')
alpha = kwargs.pop('alpha')
output = self._model.forward(*args, **kwargs)
assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
logit = output['logit']
assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
if isinstance(logit, torch.Tensor):
logit = [logit]
if 'action_mask' in output:
mask = output['action_mask']
if isinstance(mask, torch.Tensor):
mask = [mask]
logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
else:
mask = None
action = []
        for i, l in enumerate(logit):
            if np.random.random() > eps:
                # Soft (Boltzmann) policy: sample an action from softmax(Q / alpha).
                prob = torch.softmax(l / alpha, dim=-1)
                pi_action = Categorical(prob).sample()
                action.append(pi_action)
            else:
                if mask is not None:
                    action.append(sample_action(prob=mask[i].float()))
                else:
                    action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
if len(action) == 1:
action, logit = action[0], logit[0]
output['action'] = action
return output
class ActionNoiseWrapper(IModelWrapper):
r"""
Overview:
......@@ -376,6 +421,7 @@ wrapper_name_map = {
'hidden_state': HiddenStateWrapper,
'argmax_sample': ArgmaxSampleWrapper,
'eps_greedy_sample': EpsGreedySampleWrapper,
'eps_greedy_sample_sql': EpsGreedySampleWrapperSql,
'multinomial_sample': MultinomialSampleWrapper,
'action_noise': ActionNoiseWrapper,
# model wrapper
......
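For reference, a minimal sketch of how the new ``eps_greedy_sample_sql`` wrapper is driven (the obs/action shapes and the eps/alpha values are illustrative):

import torch
from ding.model import DQN, model_wrap

model = DQN(obs_shape=4, action_shape=2)
collect_model = model_wrap(model, wrapper_name='eps_greedy_sample_sql')
collect_model.reset()
with torch.no_grad():
    # eps: probability of a uniformly random action; alpha: softmax temperature over Q-values.
    output = collect_model.forward(torch.randn(8, 4), eps=0.05, alpha=0.12)
assert output['action'].shape == (8, )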
......@@ -22,6 +22,7 @@ from .coma import COMAPolicy
from .atoc import ATOCPolicy
from .acer import ACERPolicy
from .qtran import QTRANPolicy
from .sql import SQLPolicy
class EpsCommandModePolicy(CommandModePolicy):
......@@ -107,6 +108,11 @@ class SQNCommandModePolicy(SQNPolicy, DummyCommandModePolicy):
pass
@POLICY_REGISTRY.register('sql_command')
class SQLCommandModePolicy(SQLPolicy, EpsCommandModePolicy):
pass
@POLICY_REGISTRY.register('ppo_command')
class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
pass
......
from typing import List, Dict, Any, Tuple, Union, Optional
from collections import namedtuple, deque
import copy
import torch
from torch.distributions import Categorical
import logging
from easydict import EasyDict
from ding.torch_utils import Adam, to_device
from ding.utils.data import default_collate, default_decollate
from ding.rl_utils import q_nstep_td_data, q_nstep_sql_td_error, get_nstep_return_data, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from .base_policy import Policy
from .common_utils import default_preprocess_learn
@POLICY_REGISTRY.register('sql')
class SQLPolicy(Policy):
r"""
Overview:
Policy class of SQL algorithm.
"""
config = dict(
# (str) RL policy register name (refer to function "POLICY_REGISTRY").
type='sql',
# (bool) Whether to use cuda for network.
cuda=False,
# (bool) Whether the RL algorithm is on-policy or off-policy.
on_policy=False,
# (bool) Whether use priority(priority sample, IS weight, update priority)
priority=False,
# (float) Reward's future discount factor, aka. gamma.
discount_factor=0.97,
# (int) N-step reward for target q_value estimation
nstep=1,
learn=dict(
# (bool) Whether to use multi gpu
multi_gpu=False,
# How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
            update_per_collect=3,  # train on each newly collected batch of data 3 times
batch_size=64,
learning_rate=0.001,
# ==============================================================
# The following configs are algorithm-specific
# ==============================================================
# (int) Frequence of target network update.
target_update_freq=100,
# (bool) Whether ignore done(usually for max step termination env)
ignore_done=False,
alpha=0.1,
),
# collect_mode config
collect=dict(
            # (int) Only one of [n_sample, n_step, n_episode] should be set.
            # n_sample=8,  # collect 8 samples per collection and put them into the replay buffer
# (int) Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
),
eval=dict(),
# other config
other=dict(
# Epsilon greedy with decay.
eps=dict(
# (str) Decay type. Support ['exp', 'linear'].
type='exp',
start=0.95,
end=0.1,
# (int) Decay length(env step)
decay=10000,
),
replay_buffer=dict(replay_buffer_size=10000, )
),
)
def _init_learn(self) -> None:
r"""
Overview:
Learn mode init method. Called by ``self.__init__``.
Init the optimizer, algorithm config, main and target models.
"""
self._priority = self._cfg.priority
# Optimizer
self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
self._gamma = self._cfg.discount_factor
self._nstep = self._cfg.nstep
self._alpha = self._cfg.learn.alpha
# use wrapper instead of plugin
self._target_model = copy.deepcopy(self._model)
self._target_model = model_wrap(
self._target_model,
wrapper_name='target',
update_type='assign',
update_kwargs={'freq': self._cfg.learn.target_update_freq}
)
self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
self._learn_model.reset()
self._target_model.reset()
def _forward_learn(self, data: dict) -> Dict[str, Any]:
r"""
Overview:
Forward and backward function of learn mode.
Arguments:
- data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
Returns:
- info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
"""
data = default_preprocess_learn(
data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
)
if self._cuda:
data = to_device(data, self._device)
# ====================
# Q-learning forward
# ====================
self._learn_model.train()
self._target_model.train()
# Current q value (main model)
q_value = self._learn_model.forward(data['obs'])['logit']
with torch.no_grad():
# Target q value
target_q_value = self._target_model.forward(data['next_obs'])['logit']
# Max q value action (main model)
target_q_action = self._learn_model.forward(data['next_obs'])['action']
data_n = q_nstep_td_data(
q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
)
value_gamma = data.get('value_gamma')
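        # Soft Q-learning target uses V(s') = alpha * logsumexp(Q(s') / alpha) instead of max_a Q(s', a);
        # see ``q_nstep_sql_td_error`` in ding/rl_utils/td.py.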
loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(
data_n, self._gamma, self._cfg.learn.alpha, nstep=self._nstep, value_gamma=value_gamma
)
record_target_v = record_target_v.mean()
# ====================
# Q-learning update
# ====================
self._optimizer.zero_grad()
loss.backward()
if self._cfg.learn.multi_gpu:
self.sync_gradients(self._learn_model)
self._optimizer.step()
# =============
# after update
# =============
self._target_model.update(self._learn_model.state_dict())
return {
'cur_lr': self._optimizer.defaults['lr'],
'total_loss': loss.item(),
'priority': td_error_per_sample.abs().tolist(),
'record_value_function': record_target_v
# Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
# '[histogram]action_distribution': data['action'],
}
def _state_dict_learn(self) -> Dict[str, Any]:
return {
'model': self._learn_model.state_dict(),
'optimizer': self._optimizer.state_dict(),
}
def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
self._learn_model.load_state_dict(state_dict['model'])
self._optimizer.load_state_dict(state_dict['optimizer'])
def _init_collect(self) -> None:
r"""
Overview:
Collect mode init method. Called by ``self.__init__``.
Init traj and unroll length, adder, collect model.
Enable the eps_greedy_sample
"""
self._unroll_len = self._cfg.collect.unroll_len
self._gamma = self._cfg.discount_factor # necessary for parallel
self._nstep = self._cfg.nstep # necessary for parallel
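        # Path of a well-trained expert checkpoint; ``serial_pipeline_sqil`` reads the same config field
        # to build the expert policy that generates demonstration data online.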
self._info = self._cfg.collect.demonstration_info_path
self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample_sql')
self._collect_model.reset()
def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
r"""
Overview:
Forward function for collect mode with eps_greedy
Arguments:
- data (:obj:`dict`): Dict type data, including at least ['obs'].
Returns:
- data (:obj:`dict`): The collected data
"""
data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
data = to_device(data, self._device)
self._collect_model.eval()
with torch.no_grad():
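            # ``eps`` is the probability of taking a uniformly random action; ``alpha`` is the softmax
            # temperature used by the ``eps_greedy_sample_sql`` wrapper to sample from softmax(Q / alpha).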
output = self._collect_model.forward(data, eps=eps, alpha=self._cfg.learn.alpha)
if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
return {i: d for i, d in zip(data_id, output)}
def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Overview:
For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \
can be used for training directly. A train sample can be a processed transition(DQN with nstep TD) \
or some continuous transitions(DRQN).
Arguments:
- data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \
format as the return value of ``self._process_transition`` method.
Returns:
- samples (:obj:`dict`): The list of training samples.
.. note::
            We will vectorize ``process_transition`` and ``get_train_sample`` methods in the following release version. \
            And the user can customize this data processing procedure by overriding these two methods and the collector \
            itself.
"""
data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
return get_train_sample(data, self._unroll_len)
def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
r"""
Overview:
Generate dict type transition data from inputs.
Arguments:
- obs (:obj:`Any`): Env observation
- model_output (:obj:`dict`): Output of collect model, including at least ['action']
- timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
(here 'obs' indicates obs after env step).
Returns:
- transition (:obj:`dict`): Dict type transition data.
"""
transition = {
'obs': obs,
'next_obs': timestep.obs,
'action': model_output['action'],
'reward': timestep.reward,
'done': timestep.done,
}
return EasyDict(transition)
def _init_eval(self) -> None:
r"""
Overview:
Evaluate mode init method. Called by ``self.__init__``.
Init eval model with argmax strategy.
"""
self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
self._eval_model.reset()
def _forward_eval(self, data: dict) -> dict:
r"""
Overview:
Forward function for eval mode, similar to ``self._forward_collect``.
Arguments:
- data (:obj:`dict`): Dict type data, including at least ['obs'].
Returns:
- data (:obj:`dict`): Dict type data, including at least inferred action according to input obs.
"""
data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
data = to_device(data, self._device)
self._eval_model.eval()
with torch.no_grad():
output = self._eval_model.forward(data)
if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
return {i: d for i, d in zip(data_id, output)}
def default_model(self) -> Tuple[str, List[str]]:
"""
Overview:
Return this algorithm default model setting for demonstration.
Returns:
- model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names
.. note::
            The user can define and use customized network model but must obey the same interface definition indicated \
            by import_names path. For SQL, the default model is ``ding.model.template.q_learning.DQN``.
"""
return 'dqn', ['ding.model.template.q_learning']
def _monitor_vars_learn(self) -> List[str]:
return super()._monitor_vars_learn() + ['record_value_function']
......@@ -8,7 +8,8 @@ from .coma import coma_data, coma_error
from .td import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_error, td_lambda_data, td_lambda_error,\
q_nstep_td_error_with_rescale, v_1step_td_data, v_1step_td_error, v_nstep_td_data, v_nstep_td_error, \
generalized_lambda_returns, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_error, dist_nstep_td_data, \
nstep_return_data, nstep_return, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error
nstep_return_data, nstep_return, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error,\
q_nstep_sql_td_error
from .vtrace import vtrace_loss, compute_importance_weights
from .upgo import upgo_loss
from .adder import get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample
......
......@@ -6,7 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ding.rl_utils.value_rescale import value_transform, value_inv_transform
from ding.hpc_rl import hpc_wrapper
import copy
q_1step_td_data = namedtuple('q_1step_td_data', ['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight'])
......@@ -470,6 +470,69 @@ def qrdqn_nstep_td_error(
return (loss * weight).mean(), loss
def q_nstep_sql_td_error(
data: namedtuple,
gamma: float,
alpha: float,
nstep: int = 1,
cum_reward: bool = False,
value_gamma: Optional[torch.Tensor] = None,
criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
"""
Overview:
Multistep (1 step or n step) td_error for q-learning based algorithm
Arguments:
- data (:obj:`q_nstep_td_data`): the input data, q_nstep_sql_td_data to calculate loss
- gamma (:obj:`float`): discount factor
- Alpha (:obj:`float`): A parameter to weight entropy term in a policy equation
- cum_reward (:obj:`bool`): whether to use cumulative nstep reward, which is figured out when collecting data
- value_gamma (:obj:`torch.Tensor`): gamma discount value for target soft_q_value
- criterion (:obj:`torch.nn.modules`): loss function criterion
- nstep (:obj:`int`): nstep num, default set to 1
Returns:
- loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
- td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
Shapes:
- data (:obj:`q_nstep_td_data`): the q_nstep_td_data containing\
['q', 'next_n_q', 'action', 'reward', 'done']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
- action (:obj:`torch.LongTensor`): :math:`(B, )`
- next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
- done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
- td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
"""
q, next_n_q, action, next_n_action, reward, done, weight = data
assert len(action.shape) == 1, action.shape
if weight is None:
weight = torch.ones_like(action)
batch_range = torch.arange(action.shape[0])
q_s_a = q[batch_range, action]
# target_q_s_a = next_n_q[batch_range, next_n_action]
target_v = alpha * torch.logsumexp(
next_n_q / alpha, 1
) # target_v = alpha * torch.log(torch.sum(torch.exp(next_n_q / alpha), 1))
target_v[target_v == float("Inf")] = 20
target_v[target_v == float("-Inf")] = -20
# For an appropriate hyper-parameter alpha, these hardcodes can be removed.
# However, algorithms may face the danger of explosion for other alphas.
# The hardcodes above are to prevent this situation from happening
record_target_v = copy.deepcopy(target_v)
if cum_reward:
if value_gamma is None:
target_v = reward + (gamma ** nstep) * target_v * (1 - done)
else:
target_v = reward + value_gamma * target_v * (1 - done)
else:
target_v = nstep_return(nstep_return_data(reward, target_v, done), gamma, nstep, value_gamma)
td_error_per_sample = criterion(q_s_a, target_v.detach())
return (td_error_per_sample * weight).mean(), td_error_per_sample, record_target_v
iqn_nstep_td_data = namedtuple(
'iqn_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'replay_quantiles', 'weight']
)
......
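A minimal usage sketch of the new ``q_nstep_sql_td_error`` (batch size, action dim and hyper-parameters are illustrative; it assumes the re-export from ``ding.rl_utils`` shown in the __init__ change above):

import torch
from ding.rl_utils import q_nstep_td_data, q_nstep_sql_td_error

batch_size, action_dim, nstep = 4, 3, 1
q = torch.randn(batch_size, action_dim, requires_grad=True)
next_n_q = torch.randn(batch_size, action_dim)
action = torch.randint(0, action_dim, (batch_size, ))
next_n_action = torch.randint(0, action_dim, (batch_size, ))
reward = torch.randn(nstep, batch_size)  # (T, B)
done = torch.zeros(batch_size)
data = q_nstep_td_data(q, next_n_q, action, next_n_action, reward, done, None)
loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, gamma=0.99, alpha=0.1, nstep=nstep)
loss.backward()  # gradients flow back into q, which was created with requires_grad=True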
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict
pong_sqil_config = dict(
exp_name='pong_sqil',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=20,
env_id='PongNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
policy=dict(
cuda=False,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
            alpha=0.12,
        ),
        collect=dict(n_sample=96, demonstration_info_path='path'),  # set this path to a well-trained expert model checkpoint
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=250000,
),
replay_buffer=dict(replay_buffer_size=100000, ),
),
),
)
pong_sqil_config = EasyDict(pong_sqil_config)
main_config = pong_sqil_config
pong_sqil_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
    env_manager=dict(type='base', force_reproducibility=True),
policy=dict(type='sql'),
)
pong_sqil_create_config = EasyDict(pong_sqil_create_config)
create_config = pong_sqil_create_config
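# Note: this __main__ block runs plain SQL training via ``serial_pipeline``; to get the full SQIL behaviour
# (online expert collector, dual buffers, reward relabeling) use ``serial_pipeline_sqil`` from
# ding/entry/serial_entry_sqil, or the new ``serial_sqil`` CLI mode added in this commit.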
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict
pong_sql_config = dict(
exp_name='pong_sql',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=20,
env_id='PongNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
policy=dict(
cuda=False,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
            alpha=0.12,
        ),
        collect=dict(n_sample=96, demonstration_info_path=None),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=250000,
),
replay_buffer=dict(replay_buffer_size=100000, ),
),
),
)
pong_sql_config = EasyDict(pong_sql_config)
main_config = pong_sql_config
pong_sql_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
    env_manager=dict(type='base', force_reproducibility=True),
policy=dict(type='sql'),
)
pong_sql_create_config = EasyDict(pong_sql_create_config)
create_config = pong_sql_create_config
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict
qbert_dqn_config = dict(
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=30000,
env_id='QbertNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
policy=dict(
cuda=True,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
),
        collect=dict(n_sample=100, demonstration_info_path='path'),  # set this path to a well-trained expert model checkpoint
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=1000000,
),
replay_buffer=dict(replay_buffer_size=400000, ),
),
),
)
qbert_dqn_config = EasyDict(qbert_dqn_config)
main_config = qbert_dqn_config
qbert_dqn_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
    env_manager=dict(type='subprocess', force_reproducibility=True),
policy=dict(type='dqn'),
)
qbert_dqn_create_config = EasyDict(qbert_dqn_create_config)
create_config = qbert_dqn_create_config
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict
qbert_dqn_config = dict(
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=30000,
env_id='QbertNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
policy=dict(
cuda=True,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
),
        collect=dict(n_sample=100, demonstration_info_path=None),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=1000000,
),
replay_buffer=dict(replay_buffer_size=400000, ),
),
),
)
qbert_dqn_config = EasyDict(qbert_dqn_config)
main_config = qbert_dqn_config
qbert_dqn_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
    env_manager=dict(type='subprocess', force_reproducibility=True),
policy=dict(type='sql'),
)
qbert_dqn_create_config = EasyDict(qbert_dqn_create_config)
create_config = qbert_dqn_create_config
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict
space_invaders_sqil_config = dict(
exp_name='space_invaders_sqil',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=10000000000,
env_id='SpaceInvadersNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
policy=dict(
cuda=False,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
            alpha=0.1,
        ),
        collect=dict(n_sample=100, demonstration_info_path='path'),  # set this path to a well-trained expert model checkpoint
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=1000000,
),
replay_buffer=dict(replay_buffer_size=400000, ),
),
),
)
space_invaders_sqil_config = EasyDict(space_invaders_sqil_config)
main_config = space_invaders_sqil_config
space_invaders_sqil_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
    env_manager=dict(type='base', force_reproducibility=True),
policy=dict(type='sql'),
)
space_invaders_sqil_create_config = EasyDict(space_invaders_sqil_create_config)
create_config = space_invaders_sqil_create_config
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict
space_invaders_sql_config = dict(
exp_name='space_invaders_sql',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=10000000000,
env_id='SpaceInvadersNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
policy=dict(
cuda=False,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
            alpha=0.1,
        ),
        collect=dict(n_sample=100, demonstration_info_path=None),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=1000000,
),
replay_buffer=dict(replay_buffer_size=400000, ),
),
),
)
space_invaders_sql_config = EasyDict(space_invaders_sql_config)
main_config = space_invaders_sql_config
space_invaders_sql_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
    env_manager=dict(type='base', force_reproducibility=True),
policy=dict(type='sql'),
)
space_invaders_sql_create_config = EasyDict(space_invaders_sql_create_config)
create_config = space_invaders_sql_create_config
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
from easydict import EasyDict
from ding.entry import serial_pipeline
lunarlander_sqil_config = dict(
exp_name='lunarlander_sqil',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, ),
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=200,
),
policy=dict(
cuda=False,
model=dict(
obs_shape=8,
action_shape=4,
encoder_hidden_size_list=[128, 128, 64],
dueling=True,
),
nstep=1,
discount_factor=0.97,
learn=dict(
batch_size=64,
learning_rate=0.001,
            alpha=0.08,
),
collect=dict(
n_sample=64,
            demonstration_info_path='path',  # set this path to a well-trained expert model checkpoint
            # Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
),
        eval=dict(evaluator=dict(eval_freq=50, )),  # evaluate the policy every ``eval_freq`` training iterations
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.1,
decay=10000,
),
replay_buffer=dict(replay_buffer_size=20000, ),
),
),
)
lunarlander_sqil_config = EasyDict(lunarlander_sqil_config)
main_config = lunarlander_sqil_config
lunarlander_sqil_create_config = dict(
env=dict(
type='lunarlander',
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
    env_manager=dict(type='base', force_reproducibility=True),
policy=dict(type='sql'),
)
lunarlander_sqil_create_config = EasyDict(lunarlander_sqil_create_config)
create_config = lunarlander_sqil_create_config
if __name__ == "__main__":
serial_pipeline([main_config, create_config], seed=0)
\ No newline at end of file
from easydict import EasyDict
lunarlander_sql_config = dict(
exp_name='lunarlander_sql',
env=dict(
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=200,
),
policy=dict(
cuda=False,
model=dict(
obs_shape=8,
action_shape=4,
encoder_hidden_size_list=[128, 128, 64],
dueling=True,
),
nstep=1,
discount_factor=0.97,
learn=dict(
batch_size=64,
learning_rate=0.001,
            alpha=0.08,
),
collect=dict(n_sample=64, demonstration_info_path=None),
        eval=dict(evaluator=dict(eval_freq=50, )),  # evaluate the policy every ``eval_freq`` training iterations
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.1,
decay=10000,
),
replay_buffer=dict(replay_buffer_size=20000, ),
),
),
)
lunarlander_sql_config = EasyDict(lunarlander_sql_config)
main_config = lunarlander_sql_config
lunarlander_sql_create_config = dict(
env=dict(
type='lunarlander',
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
    env_manager=dict(type='base', force_reproducibility=True),
policy=dict(type='sql'),
)
lunarlander_sql_create_config = EasyDict(lunarlander_sql_create_config)
create_config = lunarlander_sql_create_config
from easydict import EasyDict
cartpole_sqil_config = dict(
exp_name='cartpole_sqil',
env=dict(
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=195,
),
policy=dict(
cuda=False,
model=dict(
obs_shape=4,
action_shape=2,
encoder_hidden_size_list=[128, 128, 64],
dueling=True,
),
nstep=1,
discount_factor=0.97,
learn=dict(
batch_size=64,
learning_rate=0.001,
            alpha=0.12,
),
        collect=dict(n_sample=8, demonstration_info_path='path'),  # set this path to a well-trained expert model checkpoint
        eval=dict(evaluator=dict(eval_freq=50, )),  # evaluate the policy every ``eval_freq`` training iterations
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.1,
decay=10000,
),
replay_buffer=dict(replay_buffer_size=20000, ),
),
),
)
cartpole_sqil_config = EasyDict(cartpole_sqil_config)
main_config = cartpole_sqil_config
cartpole_sqil_create_config = dict(
env=dict(
type='cartpole',
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
    env_manager=dict(type='base', force_reproducibility=True),
policy=dict(type='sql'),
)
cartpole_sqil_create_config = EasyDict(cartpole_sqil_create_config)
create_config = cartpole_sqil_create_config
from easydict import EasyDict
cartpole_sql_config = dict(
exp_name='cartpole_sql',
env=dict(
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=195,
),
policy=dict(
cuda=False,
model=dict(
obs_shape=4,
action_shape=2,
encoder_hidden_size_list=[128, 128, 64],
dueling=True,
),
nstep=1,
discount_factor=0.97,
learn=dict(
batch_size=64,
learning_rate=0.001,
            alpha=0.12,
),
        collect=dict(n_sample=8, demonstration_info_path=None),
        eval=dict(evaluator=dict(eval_freq=50, )),  # evaluate the policy every ``eval_freq`` training iterations
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.1,
decay=10000,
),
replay_buffer=dict(replay_buffer_size=20000, ),
),
),
)
cartpole_sql_config = EasyDict(cartpole_sql_config)
main_config = cartpole_sql_config
cartpole_sql_create_config = dict(
env=dict(
type='cartpole',
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
    env_manager=dict(type='base', force_reproducibility=True),
policy=dict(type='sql'),
)
cartpole_sql_create_config = EasyDict(cartpole_sql_create_config)
create_config = cartpole_sql_create_config