Unverified commit e2ca8738 authored by Will-Nie, committed by GitHub

feature(nyp): add DQfD algorithm (#48)

* add_dqfd

* Is_expert to is_expert

* modify according to the last comments

* value_gamma; done; marginloss; sqil compatibility

* finally shorten the code, revise config

* revise config, style

* add_readme/two_more_config

* correct format
Co-authored-by: niuyazhe <niuyazhe@sensetime.com>
Parent 8efee984
@@ -120,12 +120,13 @@ ding -m serial -e cartpole -p dqn -s 0
| 20 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
| 21 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_reward_model -c cartpole_dqn_config.py -s 0 |
| 22 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
| 23 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![IL](https://img.shields.io/badge/-IL-purple) | [policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
| 24 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 25 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 26 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 27 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 28 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
| 29 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |

![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space, which is the only label used in normal DRL algorithms (1-15)
......
@@ -4,5 +4,7 @@ from .serial_entry_onpolicy import serial_pipeline_onpolicy
from .serial_entry_offline import serial_pipeline_offline
from .serial_entry_il import serial_pipeline_il
from .serial_entry_reward_model import serial_pipeline_reward_model
from .serial_entry_dqfd import serial_pipeline_dqfd
from .serial_entry_sqil import serial_pipeline_sqil
from .parallel_entry import parallel_pipeline
from .application_entry import eval, collect_demo_data
@@ -52,7 +52,7 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
@click.option(
'-m',
'--mode',
type=click.Choice(['serial', 'serial_onpolicy', 'serial_sqil', 'serial_dqfd', 'parallel', 'dist', 'eval']),
help='serial-train or parallel-train or dist-train or eval'
)
@click.option('-c', '--config', type=str, help='Path to DRL experiment config')
@@ -157,6 +157,12 @@ def cli(
config = get_predefined_config(env, policy)
expert_config = input("Enter the name of the config you used to generate your expert model: ")
serial_pipeline_sqil(config, expert_config, seed, max_iterations=train_iter)
elif mode == 'serial_dqfd':
from .serial_entry_dqfd import serial_pipeline_dqfd
if config is None:
config = get_predefined_config(env, policy)
expert_config = input("Enter the name of the config you used to generate your expert model: ")
serial_pipeline_dqfd(config, expert_config, seed, max_iterations=train_iter)
elif mode == 'parallel':
from .parallel_entry import parallel_pipeline
parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
......
This diff is collapsed.
import pytest
import torch
from copy import deepcopy
from ding.entry import serial_pipeline
from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_dqfd_config import cartpole_dqfd_config, cartpole_dqfd_create_config
@pytest.mark.unittest
def test_dqfd():
expert_policy_state_dict_path = './expert_policy.pth'
config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
expert_policy = serial_pipeline(config, seed=0)
torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
config = [deepcopy(cartpole_dqfd_config), deepcopy(cartpole_dqfd_create_config)]
config[0].policy.collect.demonstration_info_path = expert_policy_state_dict_path
config[0].policy.learn.update_per_collect = 1
try:
serial_pipeline_dqfd(config, [cartpole_dqfd_config, cartpole_dqfd_create_config], seed=0, max_iterations=1)
except Exception:
assert False, "pipeline fail"
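The test above also doubles as a recipe for the two-stage DQfD workflow: first train a DQN expert and save its collect-mode state dict, then point the DQfD config's ``demonstration_info_path`` at that checkpoint and call the dedicated entry. Below is a minimal standalone sketch of the same flow outside pytest; it assumes ``serial_pipeline_dqfd`` keeps the signature used in the test (config, expert_config, seed, max_iterations).

import torch
from copy import deepcopy
from ding.entry import serial_pipeline
from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_dqfd_config import cartpole_dqfd_config, cartpole_dqfd_create_config

if __name__ == '__main__':
    # Stage 1: train an expert with vanilla DQN and save its collect-mode weights.
    expert_policy = serial_pipeline([deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)], seed=0)
    torch.save(expert_policy.collect_mode.state_dict(), './expert_policy.pth')
    # Stage 2: run DQfD, loading demonstrations generated by the saved expert model.
    dqfd_config = [deepcopy(cartpole_dqfd_config), deepcopy(cartpole_dqfd_create_config)]
    dqfd_config[0].policy.collect.demonstration_info_path = './expert_policy.pth'
    serial_pipeline_dqfd(dqfd_config, [cartpole_dqfd_config, cartpole_dqfd_create_config], seed=0)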
@@ -24,6 +24,7 @@ from .atoc import ATOCPolicy
from .acer import ACERPolicy
from .qtran import QTRANPolicy
from .sql import SQLPolicy
from .dqfd import DQFDPolicy
from .d4pg import D4PGPolicy
from .cql import CQLPolicy, CQLDiscretePolicy
@@ -81,6 +82,11 @@ class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
pass
@POLICY_REGISTRY.register('dqfd_command')
class DQFDCommandModePolicy(DQFDPolicy, EpsCommandModePolicy):
pass
@POLICY_REGISTRY.register('c51_command')
class C51CommandModePolicy(C51Policy, EpsCommandModePolicy):
pass
......
from typing import List, Dict, Any, Tuple
from collections import namedtuple
import copy
import torch
from ding.torch_utils import Adam, to_device
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample, \
dqfd_nstep_td_error, dqfd_nstep_td_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .dqn import DQNPolicy
from .common_utils import default_preprocess_learn
from copy import deepcopy
@POLICY_REGISTRY.register('dqfd')
class DQFDPolicy(DQNPolicy):
r"""
Overview:
Policy class of the DQfD algorithm, extending DQN with Double DQN/Dueling DQN/PER/multi-step TD.
Config:
== ==================== ======== ============== ======================================== =======================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ============== ======================================== =======================
1 ``type`` str dqn | RL policy register name, refer to | This arg is optional,
| registry ``POLICY_REGISTRY`` | a placeholder
2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
| erent from modes
3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
| or off-policy
4 ``priority`` bool True | Whether use priority(PER) | Priority sample,
| update priority
5 | ``priority_IS`` bool True | Whether use Importance Sampling Weight
| ``_weight`` | to correct biased update. If True,
| priority must be True.
6 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | May be 1 when sparse
| ``factor`` [0.95, 0.999] | gamma | reward env
7 ``nstep`` int 10, | N-step reward discount sum for target
[3, 5] | q_value estimation
8 | ``learn.update`` int 3 | How many updates(iterations) to train | This arg can vary
| ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
| valid in serial training | means more off-policy
9 | ``learn.batch_`` int 64 | The number of samples of an iteration
| ``size``
10 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
| ``_rate``
11 | ``learn.target_`` int 100 | Frequency of target network update. | Hard(assign) update
| ``update_freq``
12 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
| ``done`` | calculation. | fake termination env
13 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
| call of collector. | different envs
14 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
| ``_len``
== ==================== ======== ============== ======================================== =======================
"""
config = dict(
type='dqfd',
cuda=False,
on_policy=False,
priority=True,
# (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
priority_IS_weight=True,
discount_factor=0.99,
nstep=10,
learn=dict(
# multiplicative factor for each loss
lambda1=1.0,
lambda2=1.0,
lambda3=1e-5,
# margin function in JE, here we implement this as a constant
margin_function=0.8,
# number of pretraining iterations
per_train_iter_k=10,
# (bool) Whether to use multi gpu
multi_gpu=False,
# How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
update_per_collect=3,
batch_size=64,
learning_rate=0.001,
# ==============================================================
# The following configs are algorithm-specific
# ==============================================================
# (int) Frequency of target network update.
target_update_freq=100,
# (bool) Whether ignore done(usually for max step termination env)
ignore_done=False,
),
# collect_mode config
collect=dict(
# (int) Only one of [n_sample, n_episode] should be set
# n_sample=8,
# (int) Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
# The hyperparameter pho (the demo ratio) controls the proportion of data
# coming from expert demonstrations versus from the agent's own experience.
pho=0.5,
),
eval=dict(),
# other config
other=dict(
# Epsilon greedy with decay.
eps=dict(
# (str) Decay type. Support ['exp', 'linear'].
type='exp',
start=0.95,
end=0.1,
# (int) Decay length(env step)
decay=10000,
),
replay_buffer=dict(replay_buffer_size=10000, ),
),
)
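# Loss composition (assembled in ``_forward_learn`` via ``dqfd_nstep_td_error``): per sample,
#   lambda1 * (n-step TD loss) + (1-step TD loss) + lambda2 * (supervised large-margin loss J_E),
# where J_E is applied only to expert transitions (is_expert == 1). The L2 regularization weighted
# by lambda3 is not added to this loss explicitly; it enters through the Adam optimizer's
# ``weight_decay`` argument set in ``_init_learn``.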
def _init_learn(self) -> None:
"""
Overview:
Learn mode init method. Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \
and target model.
"""
self.lambda1 = self._cfg.learn.lambda1  # n-step return loss weight
self.lambda2 = self._cfg.learn.lambda2  # supervised (large-margin) loss weight
self.lambda3 = self._cfg.learn.lambda3  # L2 regularization, applied via optimizer weight_decay
# margin function in JE, here we implement this as a constant
self.margin_function = self._cfg.learn.margin_function
self._priority = self._cfg.priority
self._priority_IS_weight = self._cfg.priority_IS_weight
# Optimizer
self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate, weight_decay=self.lambda3)
self._gamma = self._cfg.discount_factor
self._nstep = self._cfg.nstep
# use model_wrapper for specialized demands of different modes
self._target_model = copy.deepcopy(self._model)
self._target_model = model_wrap(
self._target_model,
wrapper_name='target',
update_type='assign',
update_kwargs={'freq': self._cfg.learn.target_update_freq}
)
self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
self._learn_model.reset()
self._target_model.reset()
def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Overview:
Forward computation graph of learn mode(updating policy).
Arguments:
- data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
np.ndarray or dict/list combinations.
Returns:
- info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
recorded in text log and tensorboard, values are python scalar or a list of scalars.
ArgumentsKeys:
- necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``, ``next_obs_1``, ``done_1``, ``is_expert``
- optional: ``value_gamma``, ``IS``
ReturnsKeys:
- necessary: ``cur_lr``, ``total_loss``, ``priority``
- optional: ``action_distribution``
"""
data = default_preprocess_learn(
data,
use_priority=self._priority,
use_priority_IS_weight=self._cfg.priority_IS_weight,
ignore_done=self._cfg.learn.ignore_done,
use_nstep=True
)
data['done_1'] = data['done_1'].float()
if self._cuda:
data = to_device(data, self._device)
# ====================
# Q-learning forward
# ====================
self._learn_model.train()
self._target_model.train()
# Current q value (main model)
q_value = self._learn_model.forward(data['obs'])['logit']
# Target q value
with torch.no_grad():
target_q_value = self._target_model.forward(data['next_obs'])['logit']
target_q_value_one_step = self._target_model.forward(data['next_obs_1'])['logit']
# Max q value action (main model)
target_q_action = self._learn_model.forward(data['next_obs'])['action']
target_q_action_one_step = self._learn_model.forward(data['next_obs_1'])['action']
data_n = dqfd_nstep_td_data(
q_value,
target_q_value,
data['action'],
target_q_action,
data['reward'],
data['done'],
data['done_1'],
data['weight'],
target_q_value_one_step,
target_q_action_one_step,
data['is_expert'] # set is_expert flag(expert 1, agent 0)
)
value_gamma = data.get('value_gamma')
loss, td_error_per_sample = dqfd_nstep_td_error(
data_n,
self._gamma,
self.lambda1,
self.lambda2,
self.margin_function,
nstep=self._nstep,
value_gamma=value_gamma
)
# ====================
# Q-learning update
# ====================
self._optimizer.zero_grad()
loss.backward()
if self._cfg.learn.multi_gpu:
self.sync_gradients(self._learn_model)
self._optimizer.step()
# =============
# after update
# =============
self._target_model.update(self._learn_model.state_dict())
return {
'cur_lr': self._optimizer.defaults['lr'],
'total_loss': loss.item(),
'priority': td_error_per_sample.abs().tolist(),
# Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
# '[histogram]action_distribution': data['action'],
}
def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Overview:
For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \
can be used for training directly. A train sample can be a processed transition(DQN with nstep TD) \
or some continuous transitions(DRQN).
Arguments:
- data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \
format as the return value of ``self._process_transition`` method.
Returns:
- samples (:obj:`dict`): The list of training samples.
.. note::
We will vectorize ``process_transition`` and ``get_train_sample`` methods in the following release version. \
The user can customize this data processing procedure by overriding these two methods and the collector \
itself.
"""
data_1 = deepcopy(get_nstep_return_data(data, 1, gamma=self._gamma))
data = get_nstep_return_data(
data, self._nstep, gamma=self._gamma
) # here we want to include one-step next observation
for i in range(len(data)):
data[i]['next_obs_1'] = data_1[i]['next_obs'] # concat the one-step next observation
data[i]['done_1'] = data_1[i]['done']
return get_train_sample(data, self._unroll_len)
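For reference, below is a small shape sketch (illustration only; it assumes CartPole-like 4-dimensional observations, nstep=3, and an ``is_expert`` flag attached outside this policy when demonstration data is mixed in) of what a single training sample looks like after ``_get_train_sample``, carrying both the n-step and the one-step bootstrap fields that ``_forward_learn`` packs into ``dqfd_nstep_td_data``:

import torch

nstep = 3
sample = dict(
    obs=torch.randn(4),          # s_t
    next_obs=torch.randn(4),     # s_{t+n}, n-step bootstrap observation
    next_obs_1=torch.randn(4),   # s_{t+1}, one-step bootstrap observation (added above)
    action=torch.tensor(1),      # a_t
    reward=torch.randn(nstep),   # per-step rewards r_t .. r_{t+n-1}
    done=torch.tensor(False),    # termination flag at the n-step horizon
    done_1=torch.tensor(False),  # termination flag after one step (added above)
    is_expert=torch.tensor(1.),  # 1 for demonstration data, 0 for the agent's own data
)
print({k: tuple(v.shape) for k, v in sample.items()})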
@@ -9,7 +9,7 @@ from .td import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_e
q_nstep_td_error_with_rescale, v_1step_td_data, v_1step_td_error, v_nstep_td_data, v_nstep_td_error, \
generalized_lambda_returns, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_error, dist_nstep_td_data, \
nstep_return_data, nstep_return, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error,\
q_nstep_sql_td_error, dqfd_nstep_td_error, dqfd_nstep_td_data
from .vtrace import vtrace_loss, compute_importance_weights
from .upgo import upgo_loss
from .adder import get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample
......
@@ -259,6 +259,13 @@ q_nstep_td_data = namedtuple(
'q_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
)
dqfd_nstep_td_data = namedtuple(
'dqfd_nstep_td_data', [
'q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_1', 'weight', 'new_n_q_one_step',
'next_n_action_one_step', 'is_expert'
]
)
def shape_fn_qntd(args, kwargs):
r"""
@@ -329,6 +336,107 @@ def q_nstep_td_error(
return (td_error_per_sample * weight).mean(), td_error_per_sample
def dqfd_nstep_td_error(
data: namedtuple,
gamma: float,
lambda1: float,
lambda2: float,
margin_function: float,
nstep: int = 1,
cum_reward: bool = False,
value_gamma: Optional[torch.Tensor] = None,
criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
"""
Overview:
Multistep n-step td_error + 1-step td_error + supervised margin loss of DQfD
Arguments:
- data (:obj:`dqfd_nstep_td_data`): the input data, dqfd_nstep_td_data to calculate loss
- gamma (:obj:`float`): discount factor
- lambda1 (:obj:`float`): weight of the n-step TD loss
- lambda2 (:obj:`float`): weight of the supervised large-margin loss
- margin_function (:obj:`float`): margin value used in the supervised loss JE, implemented as a constant
- cum_reward (:obj:`bool`): whether to use cumulative nstep reward, which is figured out when collecting data
- value_gamma (:obj:`torch.Tensor`): gamma discount value for target q_value
- criterion (:obj:`torch.nn.modules`): loss function criterion
- nstep (:obj:`int`): nstep num, default set to 1
Returns:
- loss (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error + supervised margin loss, 0-dim tensor
- td_error_per_sample (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error\
+ supervised margin loss, 1-dim tensor
Shapes:
- data (:obj:`dqfd_nstep_td_data`): the dqfd_nstep_td_data containing\
['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_1', 'weight'\
, 'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
- action (:obj:`torch.LongTensor`): :math:`(B, )`
- next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
- done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
- td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
- new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
- next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
- is_expert (:obj:`torch.FloatTensor`): :math:`(B, )`, 1 for expert transitions, 0 for agent transitions
"""
q, next_n_q, action, next_n_action, reward, done, done_1, weight, new_n_q_one_step, next_n_action_one_step,\
is_expert = data # set is_expert flag(expert 1, agent 0)
assert len(action.shape) == 1, action.shape
if weight is None:
weight = torch.ones_like(action)
batch_range = torch.arange(action.shape[0])
q_s_a = q[batch_range, action]
target_q_s_a = next_n_q[batch_range, next_n_action]
target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step]
# calculate n-step TD-loss
if cum_reward:
if value_gamma is None:
target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
else:
target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
else:
target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
# calculate 1-step TD-loss
nstep = 1
reward = reward[0].unsqueeze(0)
value_gamma = None
if cum_reward:
if value_gamma is None:
target_q_s_a_one_step = reward + (gamma ** nstep) * target_q_s_a_one_step * (1 - done_1)
else:
target_q_s_a_one_step = reward + value_gamma * target_q_s_a_one_step * (1 - done_1)
else:
target_q_s_a_one_step = nstep_return(
nstep_return_data(reward, target_q_s_a_one_step, done_1), gamma, nstep, value_gamma
)
td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach())
# calculate the supervised loss
device = q_s_a.device
device_cpu = torch.device('cpu')
'''
max_action = torch.argmax(q, dim=-1)
JE = is_expert * (
q[batch_range, max_action] + margin_function *
torch.where(action == max_action, torch.ones_like(action), torch.zeros_like(action)).float().to(device) - q_s_a
)
'''
l = margin_function * torch.ones_like(q).to(device_cpu)
l.scatter_(
1, action.unsqueeze(1).to(device_cpu), torch.zeros_like(q, device=device_cpu)
) # along dim 1, zero out l at each sample's own action, so the margin vanishes for the expert action
JE = is_expert * (torch.max(q + l.to(device), dim=1)[0] - q_s_a)
'''
Js = is_expert * (
q[batch_range, max_action.type(torch.int64)] +
0.8 * torch.from_numpy((action == max_action).numpy().astype(int)).float().to(device) - q_s_a
)
'''
return ((lambda1 * td_error_per_sample + td_error_one_step_per_sample + lambda2 * JE) *
weight).mean(), td_error_per_sample + td_error_one_step_per_sample + JE
def shape_fn_qntd_rescale(args, kwargs):
r"""
Overview:
......
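To make the supervised term above concrete, here is a self-contained PyTorch illustration (not part of this commit) of how the margin matrix ``l`` and the expert loss ``J_E`` are built in ``dqfd_nstep_td_error``: ``l`` equals ``margin_function`` everywhere except at the expert's own action, so ``J_E = max_a [Q(s, a) + l(a, a_E)] - Q(s, a_E)`` is zero whenever the expert action is already the (margin-adjusted) argmax.

import torch

batch_size, action_dim, margin_function = 4, 3, 0.8
q = torch.randn(batch_size, action_dim)                     # Q(s, .) for a batch of states
action = torch.randint(0, action_dim, size=(batch_size, ))  # expert actions a_E
is_expert = torch.ones(batch_size)                          # 1 for expert transitions, 0 otherwise
batch_range = torch.arange(batch_size)
q_s_a = q[batch_range, action]                              # Q(s, a_E)
l = margin_function * torch.ones_like(q)
l.scatter_(1, action.unsqueeze(1), 0.)                      # zero margin at the expert action
JE = is_expert * (torch.max(q + l, dim=1)[0] - q_s_a)       # per-sample large-margin loss
print(JE)                                                   # non-negative; 0 where a_E is already the argmax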
import pytest
import torch
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_error, td_lambda_data,\
td_lambda_error, q_nstep_td_error_with_rescale, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_data,\
dqfd_nstep_td_data, dqfd_nstep_td_error, dist_nstep_td_error, v_1step_td_data, v_1step_td_error, v_nstep_td_data,\
v_nstep_td_error, q_nstep_sql_td_error, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data,\
qrdqn_nstep_td_error
from ding.rl_utils.td import shape_fn_dntd, shape_fn_qntd, shape_fn_td_lambda, shape_fn_qntd_rescale
@@ -214,6 +215,35 @@ def test_v_nstep_td():
assert isinstance(v.grad, torch.Tensor)
@pytest.mark.unittest
def test_dqfd_nstep_td():
batch_size = 4
action_dim = 3
next_q = torch.randn(batch_size, action_dim)
done = torch.randn(batch_size)
done_1 = torch.randn(batch_size)
next_q_one_step = torch.randn(batch_size, action_dim)
action = torch.randint(0, action_dim, size=(batch_size, ))
next_action = torch.randint(0, action_dim, size=(batch_size, ))
next_action_one_step = torch.randint(0, action_dim, size=(batch_size, ))
is_expert = torch.ones((batch_size))
for nstep in range(1, 10):
q = torch.randn(batch_size, action_dim).requires_grad_(True)
reward = torch.rand(nstep, batch_size)
data = dqfd_nstep_td_data(
q, next_q, action, next_action, reward, done, done_1, None, next_q_one_step, next_action_one_step, is_expert
)
loss, td_error_per_sample = dqfd_nstep_td_error(
data, 0.95, lambda1=1.0, lambda2=1.0, margin_function=0.8, nstep=nstep
)
assert td_error_per_sample.shape == (batch_size, )
assert loss.shape == ()
assert q.grad is None
loss.backward()
assert isinstance(q.grad, torch.Tensor)
print(loss)
@pytest.mark.unittest
def test_q_nstep_sql_td():
batch_size = 4
......
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict
pong_dqfd_config = dict(
exp_name='pong_dqfd',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=20,
env_id='PongNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=True, force_reproducibility=True)
),
policy=dict(
cuda=True,
priority=True,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
lambda1=1.0,
lambda2=1.0,
lambda3=1e-5,
per_train_iter_k=10,
# (int) Size of the expert (demonstration) replay buffer.
expert_replay_buffer_size=10000,
),
# Users should set their own path here (the path should point to a well-trained expert model).
collect=dict(n_sample=96, demonstration_info_path='path'),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=250000,
),
replay_buffer=dict(replay_buffer_size=100000, ),
),
),
)
pong_dqfd_config = EasyDict(pong_dqfd_config)
main_config = pong_dqfd_config
pong_dqfd_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqfd'),
)
pong_dqfd_create_config = EasyDict(pong_dqfd_create_config)
create_config = pong_dqfd_create_config
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
from ding.entry import serial_pipeline
from easydict import EasyDict
qbert_dqfd_config = dict(
exp_name='qbert_dqfd',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=30000,
env_id='QbertNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=True, force_reproducibility=True)
),
policy=dict(
cuda=True,
priority=True,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
lambda1=1.0,
lambda2=1.0,
lambda3=1e-5,
per_train_iter_k=10,
# (int) Size of the expert (demonstration) replay buffer.
expert_replay_buffer_size=10000,
),
# Users should set their own path here (the path should point to a well-trained expert model).
collect=dict(n_sample=100, demonstration_info_path='path'),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=1000000,
),
replay_buffer=dict(replay_buffer_size=400000, ),
),
),
)
qbert_dqfd_config = EasyDict(qbert_dqfd_config)
main_config = qbert_dqfd_config
qbert_dqfd_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqfd'),
)
qbert_dqfd_create_config = EasyDict(qbert_dqfd_create_config)
create_config = qbert_dqfd_create_config
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict
space_invaders_dqfd_config = dict(
exp_name='space_invaders_dqfd',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=10000000000,
env_id='SpaceInvadersNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=True, force_reproducibility=True)
),
policy=dict(
cuda=True,
priority=True,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
lambda1=1.0,
lambda2=1.0,
lambda3=1e-5,
per_train_iter_k=10,
# (int) Size of the expert (demonstration) replay buffer.
expert_replay_buffer_size=10000,
),
# Users should set their own path here (the path should point to a well-trained expert model).
collect=dict(n_sample=100, demonstration_info_path='path'),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=1000000,
),
replay_buffer=dict(replay_buffer_size=400000, ),
),
),
)
space_invaders_dqfd_config = EasyDict(space_invaders_dqfd_config)
main_config = space_invaders_dqfd_config
space_invaders_dqfd_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqfd'),
)
space_invaders_dqfd_create_config = EasyDict(space_invaders_dqfd_create_config)
create_config = space_invaders_dqfd_create_config
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
from easydict import EasyDict
from ding.entry import serial_pipeline
lunarlander_dqfd_config = dict(
exp_name='lunarlander_dqfd',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=200,
),
policy=dict(
cuda=True,
model=dict(
obs_shape=8,
action_shape=4,
encoder_hidden_size_list=[512, 64],
dueling=True,
),
nstep=3,
discount_factor=0.97,
learn=dict(
batch_size=64,
learning_rate=0.001,
lambda1=1.0,
lambda2=1.0,
lambda3=1e-5,
per_train_iter_k=10,
# (int) Size of the expert (demonstration) replay buffer.
expert_replay_buffer_size=10000,
),
collect=dict(
n_sample=64,
# Users should add their own path here (path should lead to a well-trained model)
demonstration_info_path='path',
# Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
),
eval=dict(evaluator=dict(eval_freq=50, )), # note: evaluate once every ``eval_freq`` training iterations
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.1,
decay=10000,
),
replay_buffer=dict(replay_buffer_size=20000, ),
),
),
)
lunarlander_dqfd_config = EasyDict(lunarlander_dqfd_config)
main_config = lunarlander_dqfd_config
lunarlander_dqfd_create_config = dict(
env=dict(
type='lunarlander',
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqfd'),
)
lunarlander_dqfd_create_config = EasyDict(lunarlander_dqfd_create_config)
create_config = lunarlander_dqfd_create_config
if __name__ == "__main__":
serial_pipeline([main_config, create_config], seed=0)
@@ -24,7 +24,7 @@ lunarlander_dqn_default_config = dict(
# Whether to use dueling head.
dueling=True,
),
# Reward's future discount factor, aka. gamma.
discount_factor=0.99,
# How many steps in td error.
nstep=nstep,
@@ -33,7 +33,7 @@ lunarlander_dqn_default_config = dict(
update_per_collect=10,
batch_size=64,
learning_rate=0.001,
# Frequency of target network update.
target_update_freq=100,
),
# collect_mode config
@@ -41,7 +41,7 @@ lunarlander_dqn_default_config = dict(
# You can use either "n_sample" or "n_episode" in collector.collect.
# Get "n_sample" samples per collect.
n_sample=64,
# Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
),
# command_mode config
......
@@ -11,4 +11,7 @@ from .cartpole_sqn_config import cartpole_sqn_config, cartpole_sqn_create_config
from .cartpole_ppg_config import cartpole_ppg_config, cartpole_ppg_create_config
from .cartpole_r2d2_config import cartpole_r2d2_config, cartpole_r2d2_create_config
from .cartpole_acer_config import cartpole_acer_config, cartpole_acer_create_config
from .cartpole_dqfd_config import cartpole_dqfd_config, cartpole_dqfd_create_config
from .cartpole_sqil_config import cartpole_sqil_config, cartpole_sqil_create_config
from .cartpole_sql_config import cartpole_sql_config, cartpole_sql_create_config
# from .cartpole_ppo_default_loader import cartpole_ppo_default_loader
from easydict import EasyDict
cartpole_dqfd_config = dict(
exp_name='cartpole_dqfd',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=195,
),
policy=dict(
cuda=True,
priority=True,
model=dict(
obs_shape=4,
action_shape=2,
encoder_hidden_size_list=[128, 128, 64],
dueling=True,
),
nstep=3,
discount_factor=0.97,
learn=dict(
batch_size=64,
learning_rate=0.001,
lambda1=1,
lambda2=3.0,
# Setting lambda3=0 (no L2 loss), together with expert_replay_buffer_size=0 and lambda1=0, recovers one-step PDD DQN.
lambda3=0,
per_train_iter_k=10,
# (int) Size of the expert (demonstration) replay buffer.
expert_replay_buffer_size=10000,
),
# Users should set their own path here (the path should point to a well-trained expert model).
collect=dict(n_sample=8, demonstration_info_path='path'),
# note: evaluate once every ``eval_freq`` training iterations
eval=dict(evaluator=dict(eval_freq=50, )),
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.1,
decay=10000,
),
replay_buffer=dict(replay_buffer_size=20000, ),
),
),
)
cartpole_dqfd_config = EasyDict(cartpole_dqfd_config)
main_config = cartpole_dqfd_config
cartpole_dqfd_create_config = dict(
env=dict(
type='cartpole',
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqfd'),
)
cartpole_dqfd_create_config = EasyDict(cartpole_dqfd_create_config)
create_config = cartpole_dqfd_create_config
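A usage sketch for this config (not part of the file above): first replace the placeholder ``demonstration_info_path='path'`` with the checkpoint of a trained expert, then either use the CLI entry added in this commit, ``ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0`` (it will prompt for the name of the expert config), or call the Python entry directly. The call below mirrors the unit test, which passes the DQfD config itself as the expert config since the model architectures match.

if __name__ == '__main__':
    from copy import deepcopy
    from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
    # main_config.policy.collect.demonstration_info_path must point to a trained expert checkpoint.
    cfg = [deepcopy(main_config), deepcopy(create_config)]
    serial_pipeline_dqfd(cfg, [main_config, create_config], seed=0)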