Unverified commit e2ca8738, authored by Will-Nie and committed by GitHub

feature(nyp): add DQfD algorithm (#48)

* add_dqfd

* Is_expert to is_expert

* modify according to the last comments

* value_gamma; done; marginloss; sqil compatibility

* finally shorten the code, revise config

* revise config, style

* add_readme/two_more_config

* correct format
Co-authored-by: niuyazhe <niuyazhe@sensetime.com>
Parent 8efee984
......@@ -120,12 +120,13 @@ ding -m serial -e cartpole -p dqn -s 0
| 20 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
| 21 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_reward_model -c cartpole_dqn_config.py -s 0 |
| 22 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
| 23 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 24 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 25 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 26 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 27 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
| 28 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |
| 23 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![IL](https://img.shields.io/badge/-IL-purple) | [policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
| 24 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 25 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 26 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 27 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 28 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
| 29 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space, which is the only label used in normal DRL algorithms (1-15)
......
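The `ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0` command in the table above maps onto the new `serial_pipeline_dqfd` entry point added in this commit. Below is a minimal programmatic sketch that mirrors the CartPole unit test later in this diff; the expert checkpoint path is an assumption and must point to a well-trained DQN model:

from copy import deepcopy
from ding.entry import serial_pipeline_dqfd
from dizoo.classic_control.cartpole.config.cartpole_dqfd_config import cartpole_dqfd_config, cartpole_dqfd_create_config

config = [deepcopy(cartpole_dqfd_config), deepcopy(cartpole_dqfd_create_config)]
# point the collector at an existing expert checkpoint (assumed local path)
config[0].policy.collect.demonstration_info_path = './expert_policy.pth'
# second argument: the expert config; the unit test in this diff passes the DQfD config pair itself
serial_pipeline_dqfd(config, [cartpole_dqfd_config, cartpole_dqfd_create_config], seed=0, max_iterations=1)

Here `max_iterations=1` only mirrors the unit test; raise it for a real training run.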
......@@ -4,5 +4,7 @@ from .serial_entry_onpolicy import serial_pipeline_onpolicy
from .serial_entry_offline import serial_pipeline_offline
from .serial_entry_il import serial_pipeline_il
from .serial_entry_reward_model import serial_pipeline_reward_model
from .serial_entry_dqfd import serial_pipeline_dqfd
from .serial_entry_sqil import serial_pipeline_sqil
from .parallel_entry import parallel_pipeline
from .application_entry import eval, collect_demo_data
......@@ -52,7 +52,7 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
@click.option(
'-m',
'--mode',
type=click.Choice(['serial', 'serial_onpolicy', 'serial_sqil', 'parallel', 'dist', 'eval']),
type=click.Choice(['serial', 'serial_onpolicy', 'serial_sqil', 'serial_dqfd', 'parallel', 'dist', 'eval']),
help='serial-train or parallel-train or dist-train or eval'
)
@click.option('-c', '--config', type=str, help='Path to DRL experiment config')
......@@ -157,6 +157,12 @@ def cli(
config = get_predefined_config(env, policy)
expert_config = input("Enter the name of the config you used to generate your expert model: ")
serial_pipeline_sqil(config, expert_config, seed, max_iterations=train_iter)
elif mode == 'serial_dqfd':
from .serial_entry_dqfd import serial_pipeline_dqfd
if config is None:
config = get_predefined_config(env, policy)
expert_config = input("Enter the name of the config you used to generate your expert model: ")
serial_pipeline_dqfd(config, expert_config, seed, max_iterations=train_iter)
elif mode == 'parallel':
from .parallel_entry import parallel_pipeline
parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
......
This diff is collapsed.
import pytest
import torch
from copy import deepcopy
from ding.entry import serial_pipeline
from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_dqfd_config import cartpole_dqfd_config, cartpole_dqfd_create_config
@pytest.mark.unittest
def test_dqfd():
expert_policy_state_dict_path = './expert_policy.pth'
config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
expert_policy = serial_pipeline(config, seed=0)
torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
config = [deepcopy(cartpole_dqfd_config), deepcopy(cartpole_dqfd_create_config)]
config[0].policy.collect.demonstration_info_path = expert_policy_state_dict_path
config[0].policy.learn.update_per_collect = 1
try:
serial_pipeline_dqfd(config, [cartpole_dqfd_config, cartpole_dqfd_create_config], seed=0, max_iterations=1)
except Exception:
assert False, "pipeline fail"
......@@ -24,6 +24,7 @@ from .atoc import ATOCPolicy
from .acer import ACERPolicy
from .qtran import QTRANPolicy
from .sql import SQLPolicy
from .dqfd import DQFDPolicy
from .d4pg import D4PGPolicy
from .cql import CQLPolicy, CQLDiscretePolicy
......@@ -81,6 +82,11 @@ class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
pass
@POLICY_REGISTRY.register('dqfd_command')
class DQFDCommandModePolicy(DQFDPolicy, EpsCommandModePolicy):
pass
@POLICY_REGISTRY.register('c51_command')
class C51CommandModePolicy(C51Policy, EpsCommandModePolicy):
pass
......
from typing import List, Dict, Any, Tuple
from collections import namedtuple
import copy
import torch
from ding.torch_utils import Adam, to_device
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample, \
dqfd_nstep_td_error, dqfd_nstep_td_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .dqn import DQNPolicy
from .common_utils import default_preprocess_learn
from copy import deepcopy
@POLICY_REGISTRY.register('dqfd')
class DQFDPolicy(DQNPolicy):
r"""
Overview:
Policy class of DQFD algorithm, extended by Double DQN/Dueling DQN/PER/multi-step TD.
Config:
== ==================== ======== ============== ======================================== =======================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ============== ======================================== =======================
1 ``type`` str dqfd | RL policy register name, refer to | This arg is optional,
| registry ``POLICY_REGISTRY`` | a placeholder
2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
| erent from modes
3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
| or off-policy
4 ``priority`` bool True | Whether use priority(PER) | Priority sample,
| update priority
5 | ``priority_IS`` bool True | Whether use Importance Sampling Weight
| ``_weight`` | to correct biased update. If True,
| priority must be True.
6 | ``discount_`` float 0.99, | Reward's future discount factor, aka. | May be 1 when sparse
| ``factor`` [0.95, 0.999] | gamma | reward env
7 ``nstep`` int 10, | N-step reward discount sum for target
[3, 5] | q_value estimation
8 | ``learn.update`` int 3 | How many updates(iterations) to train | This arg can vary
| ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
| valid in serial training | means more off-policy
9 | ``learn.batch_`` int 64 | The number of samples of an iteration
| ``size``
10 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
| ``_rate``
11 | ``learn.target_`` int 100 | Frequency of target network update. | Hard(assign) update
| ``update_freq``
12 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
| ``done`` | calculation. | fake termination env
13 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
| call of collector. | different envs
14 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
| ``_len``
== ==================== ======== ============== ======================================== =======================
"""
config = dict(
type='dqfd',
cuda=False,
on_policy=False,
priority=True,
# (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
priority_IS_weight=True,
discount_factor=0.99,
nstep=10,
learn=dict(
# multiplicative factor for each loss
lambda1=1.0,
lambda2=1.0,
lambda3=1e-5,
# margin function in JE, here we implement this as a constant
margin_function=0.8,
# number of pretraining iterations
per_train_iter_k=10,
# (bool) Whether to use multi gpu
multi_gpu=False,
# How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
update_per_collect=3,
batch_size=64,
learning_rate=0.001,
# ==============================================================
# The following configs are algorithm-specific
# ==============================================================
# (int) Frequency of target network update.
target_update_freq=100,
# (bool) Whether ignore done(usually for max step termination env)
ignore_done=False,
),
# collect_mode config
collect=dict(
# (int) Only one of [n_sample, n_episode] should be set
# n_sample=8,
# (int) Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
# The hyperparameter pho (the demo ratio) controls the proportion of data
# coming from expert demonstrations versus the agent's own experience.
pho=0.5,
),
eval=dict(),
# other config
other=dict(
# Epsilon greedy with decay.
eps=dict(
# (str) Decay type. Support ['exp', 'linear'].
type='exp',
start=0.95,
end=0.1,
# (int) Decay length(env step)
decay=10000,
),
replay_buffer=dict(replay_buffer_size=10000, ),
),
)
def _init_learn(self) -> None:
"""
Overview:
Learn mode init method. Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \
and target model.
"""
# note: the trailing commas wrap each value in a one-element tuple, hence the [0] indexing downstream
self.lambda1 = self._cfg.learn.lambda1,  # n-step return loss weight
self.lambda2 = self._cfg.learn.lambda2,  # supervised (large margin) loss weight
self.lambda3 = self._cfg.learn.lambda3,  # L2 regularization weight
# margin function in JE, here we implement this as a constant
self.margin_function = self._cfg.learn.margin_function
self._priority = self._cfg.priority
self._priority_IS_weight = self._cfg.priority_IS_weight
# Optimizer
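# lambda3, the L2 regularization term of DQfD, is applied through the optimizer's weight decay below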
self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate, weight_decay=self.lambda3[0])
self._gamma = self._cfg.discount_factor
self._nstep = self._cfg.nstep
# use model_wrapper for specialized demands of different modes
self._target_model = copy.deepcopy(self._model)
self._target_model = model_wrap(
self._target_model,
wrapper_name='target',
update_type='assign',
update_kwargs={'freq': self._cfg.learn.target_update_freq}
)
self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
self._learn_model.reset()
self._target_model.reset()
def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Overview:
Forward computation graph of learn mode(updating policy).
Arguments:
- data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
np.ndarray or dict/list combinations.
Returns:
- info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
recorded in text log and tensorboard, values are python scalar or a list of scalars.
ArgumentsKeys:
- necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
- optional: ``value_gamma``, ``IS``
ReturnsKeys:
- necessary: ``cur_lr``, ``total_loss``, ``priority``
- optional: ``action_distribution``
"""
data = default_preprocess_learn(
data,
use_priority=self._priority,
use_priority_IS_weight=self._cfg.priority_IS_weight,
ignore_done=self._cfg.learn.ignore_done,
use_nstep=True
)
data['done_1'] = data['done_1'].float()
if self._cuda:
data = to_device(data, self._device)
# ====================
# Q-learning forward
# ====================
self._learn_model.train()
self._target_model.train()
# Current q value (main model)
q_value = self._learn_model.forward(data['obs'])['logit']
# Target q value
with torch.no_grad():
target_q_value = self._target_model.forward(data['next_obs'])['logit']
target_q_value_one_step = self._target_model.forward(data['next_obs_1'])['logit']
# Max q value action (main model)
target_q_action = self._learn_model.forward(data['next_obs'])['action']
target_q_action_one_step = self._learn_model.forward(data['next_obs_1'])['action']
data_n = dqfd_nstep_td_data(
q_value,
target_q_value,
data['action'],
target_q_action,
data['reward'],
data['done'],
data['done_1'],
data['weight'],
target_q_value_one_step,
target_q_action_one_step,
data['is_expert'] # set is_expert flag(expert 1, agent 0)
)
value_gamma = data.get('value_gamma')
loss, td_error_per_sample = dqfd_nstep_td_error(
data_n,
self._gamma,
self.lambda1,
self.lambda2,
self.margin_function,
nstep=self._nstep,
value_gamma=value_gamma
)
# ====================
# Q-learning update
# ====================
self._optimizer.zero_grad()
loss.backward()
if self._cfg.learn.multi_gpu:
self.sync_gradients(self._learn_model)
self._optimizer.step()
# =============
# after update
# =============
self._target_model.update(self._learn_model.state_dict())
return {
'cur_lr': self._optimizer.defaults['lr'],
'total_loss': loss.item(),
'priority': td_error_per_sample.abs().tolist(),
# Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
# '[histogram]action_distribution': data['action'],
}
def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Overview:
For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \
can be used for training directly. A train sample can be a processed transition(DQN with nstep TD) \
or some continuous transitions(DRQN).
Arguments:
- data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \
format as the return value of ``self._process_transition`` method.
Returns:
- samples (:obj:`dict`): The list of training samples.
.. note::
We will vectorize ``process_transition`` and ``get_train_sample`` methods in the following release version. \
Users can customize this data processing procedure by overriding these two methods and the collector \
itself.
"""
data_1 = deepcopy(get_nstep_return_data(data, 1, gamma=self._gamma))
data = get_nstep_return_data(
data, self._nstep, gamma=self._gamma
) # here we want to include one-step next observation
for i in range(len(data)):
data[i]['next_obs_1'] = data_1[i]['next_obs'] # concat the one-step next observation
data[i]['done_1'] = data_1[i]['done']
return get_train_sample(data, self._unroll_len)
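For reference, the overall DQfD objective implemented by `_forward_learn` above together with `dqfd_nstep_td_error` (added to `ding/rl_utils/td.py` below) can be written as

J(Q) = \lambda_1 J_n(Q) + J_{DQ}(Q) + \lambda_2 J_E(Q) + \lambda_3 J_{L2}(Q)

J_E(Q) = \max_{a \in A} \left[ Q(s, a) + l(a_E, a) \right] - Q(s, a_E), \qquad
l(a_E, a) = \begin{cases} \text{margin\_function} & a \neq a_E \\ 0 & a = a_E \end{cases}

where J_{DQ} is the 1-step TD loss (weighted 1 in this implementation), J_n is the n-step TD loss, J_E is the supervised large-margin loss computed only on expert transitions (is_expert = 1), and J_{L2} is the L2 regularization applied via the optimizer's weight decay (lambda3).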
......@@ -9,7 +9,7 @@ from .td import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_e
q_nstep_td_error_with_rescale, v_1step_td_data, v_1step_td_error, v_nstep_td_data, v_nstep_td_error, \
generalized_lambda_returns, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_error, dist_nstep_td_data, \
nstep_return_data, nstep_return, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error,\
q_nstep_sql_td_error
q_nstep_sql_td_error, dqfd_nstep_td_error, dqfd_nstep_td_data
from .vtrace import vtrace_loss, compute_importance_weights
from .upgo import upgo_loss
from .adder import get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample
......
......@@ -259,6 +259,13 @@ q_nstep_td_data = namedtuple(
'q_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
)
dqfd_nstep_td_data = namedtuple(
'dqfd_nstep_td_data', [
'q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_1', 'weight', 'new_n_q_one_step',
'next_n_action_one_step', 'is_expert'
]
)
def shape_fn_qntd(args, kwargs):
r"""
......@@ -329,6 +336,107 @@ def q_nstep_td_error(
return (td_error_per_sample * weight).mean(), td_error_per_sample
def dqfd_nstep_td_error(
data: namedtuple,
gamma: float,
lambda1: tuple,
lambda2: tuple,
margin_function: float,
nstep: int = 1,
cum_reward: bool = False,
value_gamma: Optional[torch.Tensor] = None,
criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
"""
Overview:
Multi-step (n-step) TD error + 1-step TD error + supervised margin loss, as used in DQfD
Arguments:
- data (:obj:`dqfd_nstep_td_data`): the input data, dqfd_nstep_td_data to calculate loss
- gamma (:obj:`float`): discount factor
- cum_reward (:obj:`bool`): whether to use cumulative nstep reward, which is figured out when collecting data
- value_gamma (:obj:`torch.Tensor`): gamma discount value for target q_value
- criterion (:obj:`torch.nn.modules`): loss function criterion
- nstep (:obj:`int`): nstep num, default set to 1
Returns:
- loss (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error + supervised margin loss, 0-dim tensor
- td_error_per_sample (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error\
+ supervised margin loss, 1-dim tensor
Shapes:
- data (:obj:`dqfd_nstep_td_data`): the dqfd_nstep_td_data containing\
['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_1', 'weight',\
'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
- action (:obj:`torch.LongTensor`): :math:`(B, )`
- next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
- done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
- td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
- new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
- next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
- is_expert (:obj:`int`) : 0 or 1
"""
q, next_n_q, action, next_n_action, reward, done, done_1, weight, new_n_q_one_step, next_n_action_one_step,\
is_expert = data # set is_expert flag(expert 1, agent 0)
assert len(action.shape) == 1, action.shape
if weight is None:
weight = torch.ones_like(action)
batch_range = torch.arange(action.shape[0])
q_s_a = q[batch_range, action]
target_q_s_a = next_n_q[batch_range, next_n_action]
target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step]
# calculate n-step TD-loss
if cum_reward:
if value_gamma is None:
target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
else:
target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
else:
target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
# calculate 1-step TD-loss
nstep = 1
reward = reward[0].unsqueeze(0)
value_gamma = None
if cum_reward:
if value_gamma is None:
target_q_s_a_one_step = reward + (gamma ** nstep) * target_q_s_a_one_step * (1 - done_1)
else:
target_q_s_a_one_step = reward + value_gamma * target_q_s_a_one_step * (1 - done_1)
else:
target_q_s_a_one_step = nstep_return(
nstep_return_data(reward, target_q_s_a_one_step, done_1), gamma, nstep, value_gamma
)
td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach())
# calculate the supervised loss
device = q_s_a.device
device_cpu = torch.device('cpu')
'''
max_action = torch.argmax(q, dim=-1)
JE = is_expert * (
q[batch_range, max_action] + margin_function *
torch.where(action == max_action, torch.ones_like(action), torch.zeros_like(action)).float().to(device) - q_s_a
)
'''
l = margin_function * torch.ones_like(q).to(device_cpu)
l.scatter_(
1, torch.LongTensor(action.unsqueeze(1).to(device_cpu)), torch.zeros_like(q, device=device_cpu)
) # along dim 1, set the entry at each sample's action index to 0, so that l(a_E, a_E) = 0
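# supervised large-margin loss JE = max_a[Q(s, a) + l(a_E, a)] - Q(s, a_E), applied only to expert transitions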
JE = is_expert * (torch.max(q + l.to(device), dim=1)[0] - q_s_a)
'''
Js = is_expert * (
q[batch_range, max_action.type(torch.int64)] +
0.8 * torch.from_numpy((action == max_action).numpy().astype(int)).float().to(device) - q_s_a
)
'''
return ((lambda1[0] * td_error_per_sample + td_error_one_step_per_sample + lambda2[0] * JE) *
weight).mean(), td_error_per_sample + td_error_one_step_per_sample + JE
def shape_fn_qntd_rescale(args, kwargs):
r"""
Overview:
......
import pytest
import torch
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_error, td_lambda_data,\
td_lambda_error, q_nstep_td_error_with_rescale, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_data, \
dist_nstep_td_error, v_1step_td_data, v_1step_td_error, v_nstep_td_data, v_nstep_td_error, q_nstep_sql_td_error, \
iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error
td_lambda_error, q_nstep_td_error_with_rescale, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_data,\
dqfd_nstep_td_data, dqfd_nstep_td_error, dist_nstep_td_error, v_1step_td_data, v_1step_td_error, v_nstep_td_data,\
v_nstep_td_error, q_nstep_sql_td_error, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data,\
qrdqn_nstep_td_error
from ding.rl_utils.td import shape_fn_dntd, shape_fn_qntd, shape_fn_td_lambda, shape_fn_qntd_rescale
......@@ -214,6 +215,35 @@ def test_v_nstep_td():
assert isinstance(v.grad, torch.Tensor)
@pytest.mark.unittest
def test_dqfd_nstep_td():
batch_size = 4
action_dim = 3
next_q = torch.randn(batch_size, action_dim)
done = torch.randn(batch_size)
done_1 = torch.randn(batch_size)
next_q_one_step = torch.randn(batch_size, action_dim)
action = torch.randint(0, action_dim, size=(batch_size, ))
next_action = torch.randint(0, action_dim, size=(batch_size, ))
next_action_one_step = torch.randint(0, action_dim, size=(batch_size, ))
is_expert = torch.ones((batch_size))
for nstep in range(1, 10):
q = torch.randn(batch_size, action_dim).requires_grad_(True)
reward = torch.rand(nstep, batch_size)
data = dqfd_nstep_td_data(
q, next_q, action, next_action, reward, done, done_1, None, next_q_one_step, next_action_one_step, is_expert
)
loss, td_error_per_sample = dqfd_nstep_td_error(
data, 0.95, lambda1=(1, ), lambda2=(1, ), margin_function=0.8, nstep=nstep
)
assert td_error_per_sample.shape == (batch_size, )
assert loss.shape == ()
assert q.grad is None
loss.backward()
assert isinstance(q.grad, torch.Tensor)
print(loss)
@pytest.mark.unittest
def test_q_nstep_sql_td():
batch_size = 4
......
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict
pong_dqfd_config = dict(
exp_name='pong_dqfd',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=20,
env_id='PongNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=True, force_reproducibility=True)
),
policy=dict(
cuda=True,
priority=True,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
lambda1=1.0,
lambda2=1.0,
lambda3=1e-5,
per_train_iter_k=10,
# size of the expert (demonstration) replay buffer
expert_replay_buffer_size=10000,
),
# Users should set their own path here (it should point to a well-trained expert model)
collect=dict(n_sample=96, demonstration_info_path='path'),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=250000,
),
replay_buffer=dict(replay_buffer_size=100000, ),
),
),
)
pong_dqfd_config = EasyDict(pong_dqfd_config)
main_config = pong_dqfd_config
pong_dqfd_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqfd'),
)
pong_dqfd_create_config = EasyDict(pong_dqfd_create_config)
create_config = pong_dqfd_create_config
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
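The `demonstration_info_path` value above is a placeholder; it must point to the saved state dict of a well-trained expert policy. A minimal sketch of producing such a checkpoint, following the CartPole unit test in this diff (the Pong DQN config module path is an assumption):

import torch
from copy import deepcopy
from ding.entry import serial_pipeline
# assumed module path for an expert DQN config on the same environment
from dizoo.atari.config.serial.pong.pong_dqn_config import pong_dqn_config, pong_dqn_create_config

expert_policy = serial_pipeline([deepcopy(pong_dqn_config), deepcopy(pong_dqn_create_config)], seed=0)
# save the collect-mode state dict and point demonstration_info_path at this file
torch.save(expert_policy.collect_mode.state_dict(), './pong_expert_policy.pth')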
from ding.entry import serial_pipeline
from easydict import EasyDict
qbert_dqfd_config = dict(
exp_name='qbert_dqfd',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=30000,
env_id='QbertNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=True, force_reproducibility=True)
),
policy=dict(
cuda=True,
priority=True,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
lambda1=1.0,
lambda2=1.0,
lambda3=1e-5,
per_train_iter_k=10,
# size of the expert (demonstration) replay buffer
expert_replay_buffer_size=10000,
),
# Users should set their own path here (it should point to a well-trained expert model)
collect=dict(n_sample=100, demonstration_info_path='path'),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=1000000,
),
replay_buffer=dict(replay_buffer_size=400000, ),
),
),
)
qbert_dqfd_config = EasyDict(qbert_dqfd_config)
main_config = qbert_dqfd_config
qbert_dqfd_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqfd'),
)
qbert_dqfd_create_config = EasyDict(qbert_dqfd_create_config)
create_config = qbert_dqfd_create_config
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict
space_invaders_dqfd_config = dict(
exp_name='space_invaders_dqfd',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=10000000000,
env_id='SpaceInvadersNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=True, force_reproducibility=True)
),
policy=dict(
cuda=True,
priority=True,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
lambda1=1.0,
lambda2=1.0,
lambda3=1e-5,
per_train_iter_k=10,
# size of the expert (demonstration) replay buffer
expert_replay_buffer_size=10000,
),
# Users should set their own path here (it should point to a well-trained expert model)
collect=dict(n_sample=100, demonstration_info_path='path'),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=1000000,
),
replay_buffer=dict(replay_buffer_size=400000, ),
),
),
)
space_invaders_dqfd_config = EasyDict(space_invaders_dqfd_config)
main_config = space_invaders_dqfd_config
space_invaders_dqfd_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqfd'),
)
space_invaders_dqfd_create_config = EasyDict(space_invaders_dqfd_create_config)
create_config = space_invaders_dqfd_create_config
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
from easydict import EasyDict
from ding.entry import serial_pipeline
lunarlander_dqfd_config = dict(
exp_name='lunarlander_dqfd',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=200,
),
policy=dict(
cuda=True,
model=dict(
obs_shape=8,
action_shape=4,
encoder_hidden_size_list=[512, 64],
dueling=True,
),
nstep=3,
discount_factor=0.97,
learn=dict(
batch_size=64,
learning_rate=0.001,
lambda1=1.0,
lambda2=1.0,
lambda3=1e-5,
per_train_iter_k=10,
# size of the expert (demonstration) replay buffer
expert_replay_buffer_size=10000,
),
collect=dict(
n_sample=64,
# Users should add their own path here (path should lead to a well-trained model)
demonstration_info_path='path',
# Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
),
eval=dict(evaluator=dict(eval_freq=50, )),  # note: evaluation is run every eval_freq training iterations
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.1,
decay=10000,
),
replay_buffer=dict(replay_buffer_size=20000, ),
),
),
)
lunarlander_dqfd_config = EasyDict(lunarlander_dqfd_config)
main_config = lunarlander_dqfd_config
lunarlander_dqfd_create_config = dict(
env=dict(
type='lunarlander',
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqfd'),
)
lunarlander_dqfd_create_config = EasyDict(lunarlander_dqfd_create_config)
create_config = lunarlander_dqfd_create_config
if __name__ == "__main__":
serial_pipeline([main_config, create_config], seed=0)
......@@ -24,7 +24,7 @@ lunarlander_dqn_default_config = dict(
# Whether to use dueling head.
dueling=True,
),
# Reward's future discount facotr, aka. gamma.
# Reward's future discount factor, aka. gamma.
discount_factor=0.99,
# How many steps in td error.
nstep=nstep,
......@@ -33,7 +33,7 @@ lunarlander_dqn_default_config = dict(
update_per_collect=10,
batch_size=64,
learning_rate=0.001,
# Frequence of target network update.
# Frequency of target network update.
target_update_freq=100,
),
# collect_mode config
......@@ -41,7 +41,7 @@ lunarlander_dqn_default_config = dict(
# You can use either "n_sample" or "n_episode" in collector.collect.
# Get "n_sample" samples per collect.
n_sample=64,
# Cut trajectories into pieces with length "unrol_len".
# Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
),
# command_mode config
......
......@@ -11,4 +11,7 @@ from .cartpole_sqn_config import cartpole_sqn_config, cartpole_sqn_create_config
from .cartpole_ppg_config import cartpole_ppg_config, cartpole_ppg_create_config
from .cartpole_r2d2_config import cartpole_r2d2_config, cartpole_r2d2_create_config
from .cartpole_acer_config import cartpole_acer_config, cartpole_acer_create_config
from .cartpole_dqfd_config import cartpole_dqfd_config, cartpole_dqfd_create_config
from .cartpole_sqil_config import cartpole_sqil_config, cartpole_sqil_create_config
from .cartpole_sql_config import cartpole_sql_config, cartpole_sql_create_config
# from .cartpole_ppo_default_loader import cartpole_ppo_default_loader
from easydict import EasyDict
cartpole_dqfd_config = dict(
exp_name='cartpole_dqfd',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=195,
),
policy=dict(
cuda=True,
priority=True,
model=dict(
obs_shape=4,
action_shape=2,
encoder_hidden_size_list=[128, 128, 64],
dueling=True,
),
nstep=3,
discount_factor=0.97,
learn=dict(
batch_size=64,
learning_rate=0.001,
lambda1=1,
lambda2=3.0,
# setting lambda3=0 (no L2 loss) together with expert_replay_buffer_size=0 and lambda1=0 recovers one-step PDD DQN
lambda3=0,
per_train_iter_k=10,
# size of the expert (demonstration) replay buffer
expert_replay_buffer_size=10000,
),
# Users should set their own path here (it should point to a well-trained expert model)
collect=dict(n_sample=8, demonstration_info_path='path'),
# note: evaluation is run every eval_freq training iterations
eval=dict(evaluator=dict(eval_freq=50, )),
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.1,
decay=10000,
),
replay_buffer=dict(replay_buffer_size=20000, ),
),
),
)
cartpole_dqfd_config = EasyDict(cartpole_dqfd_config)
main_config = cartpole_dqfd_config
cartpole_dqfd_create_config = dict(
env=dict(
type='cartpole',
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqfd'),
)
cartpole_dqfd_create_config = EasyDict(cartpole_dqfd_create_config)
create_config = cartpole_dqfd_create_config