Unverified commit 16a89c35, authored by D Davide Liu, committed by GitHub

feature(davide): Implementation of D4PG (#76)

* added experience replay and n-step

* implementing distributional q value

* added distributional q-value

* added overview in qac_dist and d4pg

* derived D4PG from DDPG

* fixed a bug when action shape >1

* benchmark D4PG mujoco + minor fixes

-entry for DDPG mujoco
-entry for D4PG mujoco
-config for D4PG mujoco
-fixed D4PG code style
-unittests for QAC distributional

* formatted code

* minor updates (read description)

-added d4pg serial_entry test
-updated comments in QACDIST
-added d4pg in commander register
-added q_value in d4pg return dict
-added priority update in d4pg entry
-added assertion in QACDIST
Parent 206186f1
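For reference, the new pendulum D4PG entry introduced by this commit can presumably be launched through the usual serial pipeline, roughly as in the sketch below (a minimal sketch assuming the standard `ding.entry.serial_pipeline` API and the config names added in this diff):

```python
# Minimal launch sketch; assumes ding.entry.serial_pipeline and the new pendulum D4PG configs below.
from copy import deepcopy

from ding.entry import serial_pipeline
from dizoo.classic_control.pendulum.config import pendulum_d4pg_config, pendulum_d4pg_create_config

if __name__ == "__main__":
    config = [deepcopy(pendulum_d4pg_config), deepcopy(pendulum_d4pg_create_config)]
    serial_pipeline(config, seed=0)
```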
@@ -125,6 +125,7 @@ ding -m serial -e cartpole -p dqn -s 0
| 25 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 26 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 27 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
| 28 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space, which is the only label in normal DRL algorithms (1-15)
......
@@ -21,6 +21,7 @@ from dizoo.classic_control.cartpole.config.cartpole_r2d2_config import cartpole_
from dizoo.classic_control.pendulum.config import pendulum_ddpg_config, pendulum_ddpg_create_config
from dizoo.classic_control.pendulum.config import pendulum_td3_config, pendulum_td3_create_config
from dizoo.classic_control.pendulum.config import pendulum_sac_config, pendulum_sac_create_config
from dizoo.classic_control.pendulum.config import pendulum_d4pg_config, pendulum_d4pg_create_config
from dizoo.classic_control.bitflip.config import bitflip_her_dqn_config, bitflip_her_dqn_create_config
from dizoo.classic_control.bitflip.entry.bitflip_dqn_main import main as bitflip_dqn_main
from dizoo.multiagent_particle.config import cooperative_navigation_qmix_config, cooperative_navigation_qmix_create_config # noqa
@@ -358,16 +359,20 @@ def test_cql():
@pytest.mark.unittest
def test_discrete_cql():
# train expert
config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
def test_d4pg():
config = [deepcopy(pendulum_d4pg_config), deepcopy(pendulum_d4pg_create_config)]
config[0].policy.learn.update_per_collect = 1
config[0].exp_name = 'cartpole'
try:
serial_pipeline(config, seed=0, max_iterations=1)
except Exception:
assert False, "pipeline fail"
def test_discrete_cql():
# train expert
config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
config[0].policy.learn.update_per_collect = 1
config[0].exp_name = 'cartpole'
# collect expert data
import torch
config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
@@ -390,4 +395,4 @@ test_discrete_cql():
except Exception:
assert False, "pipeline fail"
finally:
os.popen('rm -rf cartpole cartpole_cql')
os.popen('rm -rf cartpole cartpole_cql')
\ No newline at end of file
@@ -12,3 +12,4 @@ from .sqn import SQN
from .acer import ACER
from .qtran import QTran
from .mappo import MAPPO
from .qac_dist import QACDIST
from typing import Union, Dict, Optional
import torch
import torch.nn as nn
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import RegressionHead, ReparameterizationHead, DistributionHead
@MODEL_REGISTRY.register('qac_dist')
class QACDIST(nn.Module):
r"""
Overview:
The QAC model with distributional Q-value.
Interfaces:
``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
"""
mode = ['compute_actor', 'compute_critic']
def __init__(
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
actor_head_type: str = "regression",
critic_head_type: str = "categorical",
actor_head_hidden_size: int = 64,
actor_head_layer_num: int = 1,
critic_head_hidden_size: int = 64,
critic_head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
v_min: Optional[float] = -10,
v_max: Optional[float] = 10,
n_atom: Optional[int] = 51,
) -> None:
r"""
Overview:
Init the QAC Distributional Model according to arguments.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
- action_shape (:obj:`Union[int, SequenceType]`): Action's space.
- actor_head_type (:obj:`str`): Whether to use ``regression`` or ``reparameterization``.
- critic_head_type (:obj:`str`): Only ``categorical`` is supported.
- actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
- actor_head_layer_num (:obj:`int`):
The number of layers used in the actor network to compute the action output.
- critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
- critic_head_layer_num (:obj:`int`):
The num of layers used in the network to compute Q value output for critic's nn.
- activation (:obj:`Optional[nn.Module]`):
The type of activation function to use in the ``MLP`` after ``layer_fn``;
if ``None``, it defaults to ``nn.ReLU()``.
- norm_type (:obj:`Optional[str]`):
The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
- v_min (:obj:`float`): Value of the smallest atom in the support.
- v_max (:obj:`float`): Value of the largest atom in the support.
- n_atom (:obj:`int`): Number of atoms in the support.
"""
super(QACDIST, self).__init__()
obs_shape: int = squeeze(obs_shape)
action_shape: int = squeeze(action_shape)
self.actor_head_type = actor_head_type
assert self.actor_head_type in ['regression', 'reparameterization']
if self.actor_head_type == 'regression':
self.actor = nn.Sequential(
nn.Linear(obs_shape, actor_head_hidden_size), activation,
RegressionHead(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
final_tanh=True,
activation=activation,
norm_type=norm_type
)
)
elif self.actor_head_type == 'reparameterization':
self.actor = nn.Sequential(
nn.Linear(obs_shape, actor_head_hidden_size), activation,
ReparameterizationHead(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
sigma_type='conditioned',
activation=activation,
norm_type=norm_type
)
)
self.critic_head_type = critic_head_type
assert self.critic_head_type in ['categorical'], self.critic_head_type
if self.critic_head_type == 'categorical':
self.critic = nn.Sequential(
nn.Linear(obs_shape + action_shape, critic_head_hidden_size), activation,
DistributionHead(
critic_head_hidden_size,
1,
critic_head_layer_num,
n_atom=n_atom,
v_min=v_min,
v_max=v_max,
activation=activation,
norm_type=norm_type
)
)
def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
r"""
Overview:
Use observation and action tensor to predict output.
Dispatch the input to ``compute_actor`` or ``compute_critic`` according to ``mode``.
Arguments:
Forward with ``'compute_actor'``:
- inputs (:obj:`torch.Tensor`):
The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``.
Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` is used depends on ``mode``.
Forward with ``'compute_critic'``, inputs (`Dict`) Necessary Keys:
- ``obs``, ``action`` encoded tensors.
- mode (:obj:`str`): Name of the forward mode.
Returns:
- outputs (:obj:`Dict`): Outputs of network forward.
Forward with ``'compute_actor'``, Necessary Keys (either):
- action (:obj:`torch.Tensor`): Action tensor with same size as input ``x``.
- logit (:obj:`torch.Tensor`):
Logit tensor encoding ``mu`` and ``sigma``, both with same size as input ``x``.
Forward with ``'compute_critic'``, Necessary Keys:
- q_value (:obj:`torch.Tensor`): Q-value tensor with the same size as the batch size.
- distribution (:obj:`torch.Tensor`): Q value distribution tensor.
Actor Shapes:
- inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``hidden_size``
- action (:obj:`torch.Tensor`): :math:`(B, N0)`
Critic Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
- action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
- q_value (:obj:`torch.FloatTensor`): :math:`(B, 1)`, where B is batch size
- distribution (:obj:`torch.FloatTensor`): :math:`(B, 1, N3)`, where B is batch size and N3 is ``n_atom``
Actor Examples:
>>> # Regression mode
>>> model = QACDIST(64, 64, 'regression')
>>> inputs = torch.randn(4, 64)
>>> actor_outputs = model(inputs,'compute_actor')
>>> assert actor_outputs['action'].shape == torch.Size([4, 64])
>>> # Reparameterization Mode
>>> model = QACDIST(64, 64, 'reparameterization')
>>> inputs = torch.randn(4, 64)
>>> actor_outputs = model(inputs,'compute_actor')
>>> actor_outputs['logit'][0].shape # mu
>>> torch.Size([4, 64])
>>> actor_outputs['logit'][1].shape # sigma
>>> torch.Size([4, 64])
Critic Examples:
>>> # Categorical mode
>>> inputs = {'obs': torch.randn(4,N), 'action': torch.randn(4,1)}
>>> model = QACDIST(obs_shape=(N, ),action_shape=1,actor_head_type='regression', \
... critic_head_type='categorical', n_atom=51)
>>> q_value = model(inputs, mode='compute_critic') # q value
>>> assert q_value['q_value'].shape == torch.Size([4, 1])
>>> assert q_value['distribution'].shape == torch.Size([4, 1, 51])
"""
assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
return getattr(self, mode)(inputs)
def compute_actor(self, inputs: torch.Tensor) -> Dict:
r"""
Overview:
Use the encoded embedding tensor to predict the actor output, i.e. the ``'compute_actor'`` forward mode.
Arguments:
- inputs (:obj:`torch.Tensor`):
The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``.
``hidden_size = actor_head_hidden_size``
Returns:
- outputs (:obj:`Dict`): Outputs of forward pass encoder and head.
ReturnsKeys (either):
- action (:obj:`torch.Tensor`): Continuous action tensor with same size as ``action_shape``.
- logit (:obj:`torch.Tensor`):
Logit tensor encoding ``mu`` and ``sigma``, both with same size as input ``x``.
Shapes:
- inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, where B is batch size and N0 corresponds to ``hidden_size``
- action (:obj:`torch.Tensor`): :math:`(B, N0)`
- logit (:obj:`list`): 2 elements, ``mu`` and ``sigma``, each with shape :math:`(B, N0)`.
Examples:
>>> # Regression mode
>>> model = QACDIST(64, 64, 'regression')
>>> inputs = torch.randn(4, 64)
>>> actor_outputs = model(inputs,'compute_actor')
>>> assert actor_outputs['action'].shape == torch.Size([4, 64])
>>> # Reparameterization Mode
>>> model = QACDIST(64, 64, 'reparameterization')
>>> inputs = torch.randn(4, 64)
>>> actor_outputs = model(inputs,'compute_actor')
>>> actor_outputs['logit'][0].shape # mu
>>> torch.Size([4, 64])
>>> actor_outputs['logit'][1].shape # sigma
>>> torch.Size([4, 64])
"""
x = self.actor(inputs)
if self.actor_head_type == 'regression':
return {'action': x['pred']}
elif self.actor_head_type == 'reparameterization':
return {'logit': [x['mu'], x['sigma']]}
def compute_critic(self, inputs: Dict) -> Dict:
r"""
Overview:
Use ``obs`` and ``action`` tensors to predict the distributional Q-value, i.e. the ``'compute_critic'`` forward mode.
Arguments:
- inputs (:obj:`Dict`): Dict containing ``obs`` and ``action`` encoded tensors.
Returns:
- outputs (:obj:`Dict`): Q-value output and distribution.
ReturnKeys:
- q_value (:obj:`torch.Tensor`): Q-value tensor with the same size as the batch size.
- distribution (:obj:`torch.Tensor`): Q value distribution tensor.
Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
- action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
- q_value (:obj:`torch.FloatTensor`): :math:`(B, 1)`, where B is batch size
- distribution (:obj:`torch.FloatTensor`): :math:`(B, 1, N3)`, where B is batch size and N3 is ``n_atom``
Examples:
>>> # Categorical mode
>>> inputs = {'obs': torch.randn(4,N), 'action': torch.randn(4,1)}
>>> model = QACDIST(obs_shape=(N, ),action_shape=1,actor_head_type='regression', \
... critic_head_type='categorical', n_atom=51)
>>> q_value = model(inputs, mode='compute_critic') # q value
>>> assert q_value['q_value'].shape == torch.Size([4, 1])
>>> assert q_value['distribution'].shape == torch.Size([4, 1, 51])
"""
obs, action = inputs['obs'], inputs['action']
assert len(obs.shape) == 2
if len(action.shape) == 1: # (B, ) -> (B, 1)
action = action.unsqueeze(1)
x = torch.cat([obs, action], dim=1)
x = self.critic(x)
return {'q_value': x['logit'], 'distribution': x['distribution']}
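For context, the ``q_value`` returned by ``compute_critic`` above is presumably the expectation of the categorical distribution over its fixed support, as in C51 (the conversion itself happens inside ``DistributionHead``). A minimal standalone sketch of that relation, assuming a ``(B, 1, n_atom)`` probability tensor and the ``[v_min, v_max]`` support used by ``QACDIST``:

```python
# Standalone sketch (not part of the diff): scalar Q-value as the expectation of the atom distribution.
import torch


def expected_q(distribution: torch.Tensor, v_min: float = -10., v_max: float = 10.) -> torch.Tensor:
    n_atom = distribution.shape[-1]
    support = torch.linspace(v_min, v_max, n_atom)  # atom values z_i
    return (distribution * support).sum(-1)  # E[Z] = sum_i p_i * z_i, shape (B, 1)


probs = torch.softmax(torch.randn(4, 1, 51), dim=-1)
print(expected_q(probs).shape)  # torch.Size([4, 1])
```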
import torch
import numpy as np
import pytest
from itertools import product
from ding.model.template import QACDIST
from ding.torch_utils import is_differentiable
from ding.utils import squeeze
B = 4
T = 6
embedding_size = 32
action_shape_args = [(6, ), [
1,
]]
args = list(product(*[action_shape_args, ['regression', 'reparameterization']]))
@pytest.mark.unittest
@pytest.mark.parametrize('action_shape, actor_head_type', args)
class TestQACDIST:
def test_fcqac_dist(self, action_shape, actor_head_type):
N = 32
inputs = {'obs': torch.randn(B, N), 'action': torch.randn(B, squeeze(action_shape))}
model = QACDIST(
obs_shape=(N, ),
action_shape=action_shape,
actor_head_type=actor_head_type,
critic_head_hidden_size=embedding_size,
actor_head_hidden_size=embedding_size,
)
# compute_q
q = model(inputs, mode='compute_critic')
is_differentiable(q['q_value'].sum(), model.critic)
if isinstance(action_shape, int):
assert q['q_value'].shape == (B, 1)
assert q['distribution'].shape == (B, 1, 51)
elif len(action_shape) == 1:
assert q['q_value'].shape == (B, 1)
assert q['distribution'].shape == (B, 1, 51)
# compute_action
print(model)
if actor_head_type == 'regression':
action = model(inputs['obs'], mode='compute_actor')['action']
if squeeze(action_shape) == 1:
assert action.shape == (B, )
else:
assert action.shape == (B, squeeze(action_shape))
assert action.eq(action.clamp(-1, 1)).all()
is_differentiable(action.sum(), model.actor)
elif actor_head_type == 'reparameterization':
(mu, sigma) = model(inputs['obs'], mode='compute_actor')['logit']
assert mu.shape == (B, *action_shape)
assert sigma.shape == (B, *action_shape)
is_differentiable(mu.sum() + sigma.sum(), model.actor)
@@ -5,6 +5,7 @@ from .qrdqn import QRDQNPolicy
from .c51 import C51Policy
from .rainbow import RainbowDQNPolicy
from .ddpg import DDPGPolicy
from .d4pg import D4PGPolicy
from .td3 import TD3Policy
from .a2c import A2CPolicy
from .ppo import PPOPolicy
......
@@ -24,6 +24,7 @@ from .atoc import ATOCPolicy
from .acer import ACERPolicy
from .qtran import QTRANPolicy
from .sql import SQLPolicy
from .d4pg import D4PGPolicy
from .cql import CQLPolicy, CQLDiscretePolicy
@@ -198,3 +199,7 @@ class ACERCommandModePolisy(ACERPolicy, DummyCommandModePolicy):
@POLICY_REGISTRY.register('qtran_command')
class QTRANCommandModePolicy(QTRANPolicy, EpsCommandModePolicy):
pass
@POLICY_REGISTRY.register('d4pg_command')
class D4PGCommandModePolicy(D4PGPolicy, DummyCommandModePolicy):
pass
\ No newline at end of file
from typing import List, Dict, Any, Tuple, Union
import torch
import copy
from ding.torch_utils import Adam, to_device
from ding.rl_utils import get_train_sample
from ding.rl_utils import dist_nstep_td_data, dist_nstep_td_error, get_nstep_return_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from .ddpg import DDPGPolicy
from .common_utils import default_preprocess_learn
import numpy as np
@POLICY_REGISTRY.register('d4pg')
class D4PGPolicy(DDPGPolicy):
r"""
Overview:
Policy class of D4PG algorithm.
Property:
learn_mode, collect_mode, eval_mode
Config:
== ==================== ======== ============= ================================= =======================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ============= ================================= =======================
1 ``type`` str d4pg | RL policy register name, refer | this arg is optional,
| to registry ``POLICY_REGISTRY`` | a placeholder
2 ``cuda`` bool True | Whether to use cuda for network |
3 | ``random_`` int 25000 | Number of randomly collected | Default to 25000 for
| ``collect_size`` | training samples in replay | DDPG/TD3, 10000 for
| | buffer when training starts. | sac.
5 | ``learn.learning`` float 1e-3 | Learning rate for actor |
| ``_rate_actor`` | network(aka. policy). |
6 | ``learn.learning`` float 1e-3 | Learning rates for critic |
| ``_rate_critic`` | network (aka. Q-network). |
7 | ``learn.actor_`` int 1 | When critic network updates | Default 1
| ``update_freq`` | once, how many times will actor |
| | network update. |
8 | ``learn.noise`` bool False | Whether to add noise on target | Default False for
| | network's action. | D4PG.
| | | Target Policy Smoo-
| | | thing Regularization
| | | in TD3 paper.
9 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only
| ``ignore_done`` | done flag. | in halfcheetah env.
10 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation
| ``target_theta`` | target network. | factor in polyak aver
| | | aging for target
| | | networks.
11 | ``collect.-`` float 0.1 | Used for add noise during co- | Sample noise from dis
| ``noise_sigma`` | llection, through controlling | tribution, Gaussian
| | the sigma of distribution | process.
12 | ``model.v_min`` float -10 | Value of the smallest atom |
| | in the support set. |
13 | ``model.v_max`` float 10 | Value of the largest atom |
| | in the support set. |
14 | ``model.n_atom`` int 51 | Number of atoms in the support |
| | set of the value distribution. |
15 | ``nstep`` int 3, | N-step reward discount sum for |
| [1, 5] | target q_value estimation |
16 | ``priority`` bool True | Whether use priority(PER) | priority sample,
| update priority
== ==================== ======== ============= ================================= =======================
"""
config = dict(
# (str) RL policy register name (refer to function "POLICY_REGISTRY").
type='d4pg',
# (bool) Whether to use cuda for network.
cuda=False,
# (bool type) on_policy: Determine whether on-policy or off-policy.
# on-policy setting influences the behaviour of buffer.
# Default False in D4PG.
on_policy=False,
# (bool) Whether use priority(priority sample, IS weight, update priority)
# Default True in D4PG.
priority=True,
# (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
priority_IS_weight=True,
# (int) Number of training samples(randomly collected) in replay buffer when training starts.
# Default 25000 in D4PG.
random_collect_size=25000,
# (int) N-step reward for target q_value estimation
nstep=3,
model=dict(
# (float) Value of the smallest atom in the support set.
# Default to -10.0.
v_min=-10,
# (float) Value of the largest atom in the support set.
# Default to 10.0.
v_max=10,
# (int) Number of atoms in the support set of the
# value distribution. Default to 51.
n_atom=51
),
learn=dict(
# (bool) Whether to use multi gpu
multi_gpu=False,
# How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
update_per_collect=1,
# (int) Minibatch size for gradient descent.
batch_size=256,
# Learning rates for actor network(aka. policy).
learning_rate_actor=1e-3,
# Learning rates for critic network(aka. Q-network).
learning_rate_critic=1e-3,
# (bool) Whether to ignore the done flag (usually for envs terminated by a max-step limit, e.g. pendulum).
# Note: Gym wraps the MuJoCo envs with TimeLimit wrappers by default,
# which limit HalfCheetah and several other MuJoCo envs to a max length of 1000 steps.
# For such time-limit terminations we replace done==True with done==False,
# so that the TD-error computation (``gamma * (1 - done) * next_v + reward``)
# still bootstraps correctly when the episode is cut off by the step limit.
ignore_done=False,
# (float type) target_theta: Used for soft update of the target network,
# aka. Interpolation factor in polyak averaging for target networks.
# Default to 0.005.
target_theta=0.005,
# (float) discount factor for the discounted sum of rewards, aka. gamma.
discount_factor=0.99,
# (int) When critic network updates once, how many times will actor network update.
actor_update_freq=1,
# (bool) Whether to add noise on target network's action.
# Target Policy Smoothing Regularization in original TD3 paper.
noise=False,
),
collect=dict(
# (int) Only one of [n_sample, n_episode] should be set
n_sample=1,
# (int) Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
# Noise must be added during collection, so "noise" is omitted here and only "noise_sigma" is set.
noise_sigma=0.1,
),
eval=dict(evaluator=dict(eval_freq=1000, ), ),
other=dict(
replay_buffer=dict(
# (int) Maximum size of replay buffer.
replay_buffer_size=1000000,
),
),
)
def _init_learn(self) -> None:
r"""
Overview:
Learn mode init method. Called by ``self.__init__``.
Init actor and critic optimizers, algorithm config, main and target models.
"""
self._priority = self._cfg.priority
self._priority_IS_weight = self._cfg.priority_IS_weight
# actor and critic optimizer
self._optimizer_actor = Adam(
self._model.actor.parameters(),
lr=self._cfg.learn.learning_rate_actor,
)
self._optimizer_critic = Adam(
self._model.critic.parameters(),
lr=self._cfg.learn.learning_rate_critic,
)
self._use_reward_batch_norm = self._cfg.get('use_reward_batch_norm', False)
self._gamma = self._cfg.learn.discount_factor
self._nstep = self._cfg.nstep
self._actor_update_freq = self._cfg.learn.actor_update_freq
# main and target models
self._target_model = copy.deepcopy(self._model)
self._target_model = model_wrap(
self._target_model,
wrapper_name='target',
update_type='momentum',
update_kwargs={'theta': self._cfg.learn.target_theta}
)
if self._cfg.learn.noise:
self._target_model = model_wrap(
self._target_model,
wrapper_name='action_noise',
noise_type='gauss',
noise_kwargs={
'mu': 0.0,
'sigma': self._cfg.learn.noise_sigma
},
noise_range=self._cfg.learn.noise_range
)
self._learn_model = model_wrap(self._model, wrapper_name='base')
self._learn_model.reset()
self._target_model.reset()
self._v_max = self._cfg.model.v_max
self._v_min = self._cfg.model.v_min
self._n_atom = self._cfg.model.n_atom
self._forward_learn_cnt = 0 # count iterations
def _forward_learn(self, data: dict) -> Dict[str, Any]:
r"""
Overview:
Forward and backward function of learn mode.
Arguments:
- data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
Returns:
- info_dict (:obj:`Dict[str, Any]`): Including at least actor and critic lr, different losses.
"""
loss_dict = {}
data = default_preprocess_learn(
data,
use_priority=self._cfg.priority,
use_priority_IS_weight=self._cfg.priority_IS_weight,
ignore_done=self._cfg.learn.ignore_done,
use_nstep=True
)
if self._cuda:
data = to_device(data, self._device)
# ====================
# critic learn forward
# ====================
self._learn_model.train()
self._target_model.train()
next_obs = data.get('next_obs')
reward = data.get('reward')
if self._use_reward_batch_norm:
reward = (reward - reward.mean()) / (reward.std() + 1e-8)
# current q value
q_value = self._learn_model.forward(data, mode='compute_critic')
q_value_dict = {}
q_dist = q_value['distribution']
q_value_dict['q_value'] = q_value['q_value'].mean()
# target q value. SARSA: first predict next action, then calculate next q value
with torch.no_grad():
next_action = self._target_model.forward(next_obs, mode='compute_actor')['action']
next_data = {'obs': next_obs, 'action': next_action}
target_q_dist = self._target_model.forward(next_data, mode='compute_critic')['distribution']
value_gamma = data.get('value_gamma')
# The critic head outputs a single Q-value distribution (output size 1), so the action index is always 0.
action_index = np.zeros(next_action.shape[0])
td_data = dist_nstep_td_data(
q_dist, target_q_dist, action_index, action_index, reward, data['done'], data['weight']
)
critic_loss, td_error_per_sample = dist_nstep_td_error(
td_data, self._gamma, self._v_min, self._v_max, self._n_atom, nstep=self._nstep, value_gamma=value_gamma
)
loss_dict['critic_loss'] = critic_loss
# ================
# critic update
# ================
self._optimizer_critic.zero_grad()
for k in loss_dict:
if 'critic' in k:
loss_dict[k].backward()
self._optimizer_critic.step()
# ===============================
# actor learn forward and update
# ===============================
# actor updates every ``self._actor_update_freq`` iters
if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
actor_data['obs'] = data['obs']
actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'].mean()
loss_dict['actor_loss'] = actor_loss
# actor update
self._optimizer_actor.zero_grad()
actor_loss.backward()
self._optimizer_actor.step()
# =============
# after update
# =============
loss_dict['total_loss'] = sum(loss_dict.values())
self._forward_learn_cnt += 1
self._target_model.update(self._learn_model.state_dict())
return {
'cur_lr_actor': self._optimizer_actor.defaults['lr'],
'cur_lr_critic': self._optimizer_critic.defaults['lr'],
'q_value': q_value['q_value'].detach().cpu().numpy().mean(),
'action': data.get('action').mean(),
'priority': td_error_per_sample.abs().tolist(),
**loss_dict,
**q_value_dict,
}
def _get_train_sample(self, traj: list) -> Union[None, List[Any]]:
r"""
Overview:
Compute the n-step return data from the trajectory, then generate training samples from it.
Arguments:
- traj (:obj:`list`): The trajectory's buffer list
Returns:
- samples (:obj:`list`): The list of generated training samples.
"""
data = get_nstep_return_data(traj, self._nstep, gamma=self._gamma)
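# Note: roughly speaking, get_nstep_return_data attaches to each transition the next ``nstep`` rewards
# and the observation ``nstep`` steps ahead, so the distributional critic target in ``_forward_learn``
# bootstraps from obs_{t+n}; the discounting of the reward sum is applied inside ``dist_nstep_td_error``.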
return get_train_sample(data, self._unroll_len)
def default_model(self) -> Tuple[str, List[str]]:
return 'qac_dist', ['ding.model.template.qac_dist']
def _monitor_vars_learn(self) -> List[str]:
r"""
Overview:
Return the names of the variables to be monitored.
Returns:
- vars (:obj:`List[str]`): Variables' name list.
"""
ret = ['cur_lr_actor', 'cur_lr_critic', 'critic_loss', 'actor_loss', 'total_loss', 'q_value', 'action']
return ret
import os
import gym
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import C51Policy
from ding.model import C51DQN
from ding.utils import set_pkg_seed
from ding.rl_utils import get_epsilon_greedy_fn
from dizoo.classic_control.cartpole.config.cartpole_c51_config import cartpole_c51_config
# Get DI-engine form env class
def wrapped_cartpole_env():
return DingEnvWrapper(gym.make('CartPole-v0'))
def main(cfg, seed=0):
cfg = compile_config(
cfg,
BaseEnvManager,
C51Policy,
BaseLearner,
SampleSerialCollector,
InteractionSerialEvaluator,
AdvancedReplayBuffer,
save_cfg=True
)
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
# Set random seed for all package and instance
collector_env.seed(seed)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
# Set up RL Policy
model = C51DQN(**cfg.policy.model)
policy = C51Policy(cfg.policy, model=model)
# Set up collection, training and evaluation utilities
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
# Set up other modules, e.g. epsilon greedy
eps_cfg = cfg.policy.other.eps
epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
# Training & Evaluation loop
while True:
# Evaluating at the beginning and with specific frequency
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Update other modules
eps = epsilon_greedy(collector.envstep)
# Sampling data from environments
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# Training
for i in range(cfg.policy.learn.update_per_collect):
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is None:
break
learner.train(train_data, collector.envstep)
# evaluate
evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
evaluator_env.enable_save_replay(cfg.env.replay_path) # switch save replay interface
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if __name__ == "__main__":
main(cartpole_c51_config)
from .pendulum_ddpg_config import pendulum_ddpg_config, pendulum_ddpg_create_config
from .pendulum_td3_config import pendulum_td3_config, pendulum_td3_create_config
from .pendulum_sac_config import pendulum_sac_config, pendulum_sac_create_config
from .pendulum_d4pg_config import pendulum_d4pg_config, pendulum_d4pg_create_config
from easydict import EasyDict
pendulum_d4pg_config = dict(
env=dict(
collector_env_num=8,
evaluator_env_num=5,
# (bool) Scale output action into legal range.
act_scale=True,
n_evaluator_episode=5,
stop_value=-250,
),
policy=dict(
cuda=False,
priority=True,
nstep=3,
random_collect_size=800,
model=dict(
obs_shape=3,
action_shape=1,
actor_head_type='regression',
v_min=-100,
v_max=100,
n_atom=51,
),
learn=dict(
update_per_collect=2,
batch_size=128,
learning_rate_actor=0.001,
learning_rate_critic=0.001,
ignore_done=True,
actor_update_freq=1,
noise=False,
),
collect=dict(
n_sample=48,
noise_sigma=0.1,
collector=dict(collect_print_freq=1000, ),
),
eval=dict(evaluator=dict(eval_freq=100, ), ),
other=dict(replay_buffer=dict(
replay_buffer_size=20000,
max_use=16,
), ),
),
)
pendulum_d4pg_config = EasyDict(pendulum_d4pg_config)
main_config = pendulum_d4pg_config
pendulum_d4pg_create_config = dict(
env=dict(
type='pendulum',
import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
),
env_manager=dict(type='base'),
policy=dict(type='d4pg'),
)
pendulum_d4pg_create_config = EasyDict(pendulum_d4pg_create_config)
create_config = pendulum_d4pg_create_config
import os
import gym
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs import BaseEnvManager
from ding.policy.d4pg import D4PGPolicy
from ding.model.template import QACDIST
from ding.utils import set_pkg_seed
from dizoo.classic_control.pendulum.envs import PendulumEnv
from dizoo.classic_control.pendulum.config.pendulum_d4pg_config import pendulum_d4pg_config
def main(cfg, seed=0):
cfg = compile_config(
cfg,
BaseEnvManager,
D4PGPolicy,
BaseLearner,
SampleSerialCollector,
InteractionSerialEvaluator,
AdvancedReplayBuffer,
save_cfg=True
)
# Set up envs for collection and evaluation
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env = BaseEnvManager(
env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
)
evaluator_env = BaseEnvManager(
env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
)
# Set random seed for all package and instance
collector_env.seed(seed)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
# Set up RL Policy
model = QACDIST(**cfg.policy.model)
policy = D4PGPolicy(cfg.policy, model=model)
# Set up collection, training and evaluation utilities
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
# Training & Evaluation loop
while True:
# Evaluate at the beginning and with specific frequency
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Collect data from environments
new_data = collector.collect(train_iter=learner.train_iter)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# Train
for i in range(cfg.policy.learn.update_per_collect):
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is None:
break
learner.train(train_data, collector.envstep)
replay_buffer.update(learner.priority_info)
if __name__ == "__main__":
main(pendulum_d4pg_config, seed=0)
import os
import gym
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import DDPGPolicy
from ding.model import QAC
from ding.utils import set_pkg_seed
from dizoo.classic_control.pendulum.envs import PendulumEnv
from dizoo.classic_control.pendulum.config.pendulum_ddpg_config import pendulum_ddpg_config
def main(cfg, seed=0):
cfg = compile_config(
cfg,
BaseEnvManager,
DDPGPolicy,
BaseLearner,
SampleSerialCollector,
InteractionSerialEvaluator,
AdvancedReplayBuffer,
save_cfg=True
)
# Set up envs for collection and evaluation
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env = BaseEnvManager(
env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
)
evaluator_env = BaseEnvManager(
env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
)
# Set random seed for all package and instance
collector_env.seed(seed)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
# Set up RL Policy
model = QAC(**cfg.policy.model)
policy = DDPGPolicy(cfg.policy, model=model)
# Set up collection, training and evaluation utilities
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
# Training & Evaluation loop
while True:
# Evaluate at the beginning and with specific frequency
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Collect data from environments
new_data = collector.collect(train_iter=learner.train_iter)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# Train
for i in range(cfg.policy.learn.update_per_collect):
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is None:
break
learner.train(train_data, collector.envstep)
if __name__ == "__main__":
main(pendulum_ddpg_config, seed=0)
from .hopper_cql_default_config import hopper_cql_default_config
from .hopper_expert_cql_default_config import hopper_expert_cql_default_config
from .hopper_medium_cql_default_config import hopper_medium_cql_default_config
\ No newline at end of file
from .hopper_medium_cql_default_config import hopper_medium_cql_default_config
@@ -369,7 +369,7 @@ class D4RLEnv(BaseEnv):
info.rew_space.shape = rew_shape
return info
else:
raise NotImplementedError('{} not found in D4RL_INFO_DICT [{}]' \
raise NotImplementedError('{} not found in D4RL_INFO_DICT [{}]'
.format(self._cfg.env_id, D4RL_INFO_DICT.keys()))
def _make_env(self, only_info=False):
......
from easydict import EasyDict
hopper_d4pg_default_config = dict(
env=dict(
env_id='Hopper-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=3000,
),
policy=dict(
cuda=True,
priority=True,
nstep=5,
on_policy=False,
random_collect_size=25000,
model=dict(
obs_shape=11,
action_shape=3,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
critic_head_type='categorical',
v_min=-100,
v_max=100,
n_atom=51,
),
learn=dict(
update_per_collect=1,
batch_size=256,
learning_rate_actor=1e-3,
learning_rate_critic=1e-3,
ignore_done=False,
target_theta=0.005,
discount_factor=0.99,
actor_update_freq=1,
noise=False,
),
collect=dict(
n_sample=1,
unroll_len=1,
noise_sigma=0.1,
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
)
)
hopper_d4pg_default_config = EasyDict(hopper_d4pg_default_config)
main_config = hopper_d4pg_default_config
hopper_d4pg_default_create_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='d4pg',
import_names=['ding.policy.d4pg'],
),
)
hopper_d4pg_default_create_config = EasyDict(hopper_d4pg_default_create_config)
create_config = hopper_d4pg_default_create_config
import os
import gym
from tensorboardX import SummaryWriter
from easydict import EasyDict
from ding.config import compile_config
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import D4PGPolicy
from ding.model.template.qac_dist import QACDIST
from ding.utils import set_pkg_seed
from dizoo.classic_control.pendulum.envs import PendulumEnv
from dizoo.mujoco.envs.mujoco_env import MujocoEnv
from dizoo.classic_control.pendulum.config.pendulum_ppo_config import pendulum_ppo_config
from dizoo.mujoco.config.hopper_d4pg_default_config import hopper_d4pg_default_config
def main(cfg, seed=0, max_iterations=int(1e10)):
cfg = compile_config(
cfg,
BaseEnvManager,
D4PGPolicy,
BaseLearner,
SampleSerialCollector,
InteractionSerialEvaluator,
AdvancedReplayBuffer,
save_cfg=True
)
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env = BaseEnvManager(
env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
)
evaluator_env = BaseEnvManager(
env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
)
collector_env.seed(seed, dynamic_seed=True)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
model = QACDIST(**cfg.policy.model)
policy = D4PGPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger)
collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
for _ in range(max_iterations):
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Collect data from environments
new_data = collector.collect(train_iter=learner.train_iter)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# Train
for i in range(cfg.policy.learn.update_per_collect):
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is None:
break
learner.train(train_data, collector.envstep)
replay_buffer.update(learner.priority_info)
if __name__ == "__main__":
main(hopper_d4pg_default_config)
import os
import gym
from tensorboardX import SummaryWriter
from easydict import EasyDict
from ding.config import compile_config
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import DDPGPolicy
from ding.model import QAC
from ding.utils import set_pkg_seed
from dizoo.classic_control.pendulum.envs import PendulumEnv
from dizoo.mujoco.envs.mujoco_env import MujocoEnv
from dizoo.classic_control.pendulum.config.pendulum_ppo_config import pendulum_ppo_config
from dizoo.mujoco.config.hopper_ddpg_default_config import hopper_ddpg_default_config
def main(cfg, seed=0, max_iterations=int(1e10)):
cfg = compile_config(
cfg,
BaseEnvManager,
DDPGPolicy,
BaseLearner,
SampleSerialCollector,
InteractionSerialEvaluator,
AdvancedReplayBuffer,
save_cfg=True
)
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env = BaseEnvManager(
env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
)
evaluator_env = BaseEnvManager(
env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
)
collector_env.seed(seed, dynamic_seed=True)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
model = QAC(**cfg.policy.model)
policy = DDPGPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger)
collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
for _ in range(max_iterations):
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Collect data from environments
new_data = collector.collect(train_iter=learner.train_iter)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# Train
for i in range(cfg.policy.learn.update_per_collect):
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is None:
break
learner.train(train_data, collector.envstep)
if __name__ == "__main__":
main(hopper_ddpg_default_config)