TD3¶
Overview¶
Twin Delayed DDPG (TD3), proposed in the 2018 paper Addressing Function Approximation Error in Actor-Critic Methods, is an algorithm that considers the interplay between function approximation error in both policy and value updates. TD3 is an actor-critic, model-free algorithm based on the deep deterministic policy gradient (DDPG) that addresses overestimation bias and the accumulation of error in temporal difference methods in continuous action spaces.
Quick Facts¶
TD3 is only used for environments with continuous action spaces (e.g. MuJoCo).
TD3 is an off-policy algorithm.
TD3 is a model-free and actor-critic RL algorithm, which optimizes actor network and critic network, respectively.
Key Equations or Key Graphs¶
TD3 proposes a clipped Double Q-learning variant which leverages the notion that a value estimate suffering from overestimation bias can be used as an approximate upper-bound to the true value estimate.
First, TD3 shows that target networks, a common approach in deep Q-learning methods, are critical for variance reduction by reducing the accumulation of errors.
Second, to address the coupling of value and policy, TD3 proposes delaying policy updates until the value estimate has converged.
Finally, TD3 introduces a novel regularization strategy (Target Policy Smoothing Regularization), where a SARSA-style update bootstraps similar action estimates to further reduce variance.
The target update of the Clipped Double Q-learning algorithm:

\[y_1 = r + \gamma \min_{i=1,2} Q_{\theta_i'}\big(s', \pi_{\phi_1}(s')\big)\]

In implementation, computational costs can be reduced by using a single actor optimized with respect to \(Q_{\theta_1}\). We then use the same target \(y_2 = y_1\) for \(Q_{\theta_2}\).
A concern with deterministic policies is that they can overfit to narrow peaks in the value estimate. When updating the critic, a learning target using a deterministic policy is highly susceptible to inaccuracies induced by function approximation error, which increases the variance of the target. TD3 introduces a regularization strategy for deep value learning, target policy smoothing, which mimics the learning update from SARSA. Specifically, TD3 approximates the expectation over actions by adding a small amount of clipped random noise to the target policy and averaging over mini-batches:

\[y = r + \gamma Q_{\theta'}\big(s', \pi_{\phi'}(s') + \epsilon\big), \quad \epsilon \sim \mathrm{clip}\big(\mathcal{N}(0, \sigma), -c, c\big)\]
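As a minimal PyTorch sketch of the two ideas above (the function and tensor names such as target_actor and target_critic1 are illustrative placeholders, not the DI-engine API), the clipped double-Q target with target policy smoothing can be computed as follows:

import torch

def td3_target(reward, done, next_obs, target_actor, target_critic1, target_critic2,
               gamma=0.99, noise_sigma=0.2, noise_clip=0.5, act_low=-1.0, act_high=1.0):
    # All names and hyper-parameters here are illustrative placeholders.
    with torch.no_grad():
        # Target policy smoothing: perturb the target action with clipped Gaussian noise.
        next_action = target_actor(next_obs)
        noise = (torch.randn_like(next_action) * noise_sigma).clamp(-noise_clip, noise_clip)
        next_action = (next_action + noise).clamp(act_low, act_high)
        # Clipped Double Q-learning: bootstrap from the minimum of the two target critics.
        min_q = torch.min(target_critic1(next_obs, next_action),
                          target_critic2(next_obs, next_action))
        # One-step TD target; (1 - done) masks terminal transitions.
        return reward + gamma * (1.0 - done) * min_q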
Pseudocode¶
Extensions¶
- TD3 can be combined with:
  - Target Network
    Addressing Function Approximation Error in Actor-Critic Methods uses a soft-update target network to keep the TD-error small. We implement the soft-update target network for the actor-critic through the TargetNetworkWrapper in model_wrap and by configuring learn.target_theta.
  - Policy Updates Delay
    Addressing Function Approximation Error in Actor-Critic Methods proposes delaying policy updates until the value error is as small as possible. Therefore, TD3 only updates the policy and target networks after a fixed number of updates \(d\) to the critic. We implement Policy Updates Delay through configuring learn.actor_update_freq.
  - Target Policy Smoothing
    Addressing Function Approximation Error in Actor-Critic Methods proposes Target Policy Smoothing Regularization to reduce the variance caused by deterministic policies. We implement Target Policy Smoothing through configuring learn.noise, learn.noise_sigma, and learn.noise_range.
  - Clipped Double-Q Learning
    Addressing Function Approximation Error in Actor-Critic Methods proposes Clipped Double Q-learning, which greatly reduces overestimation by the critic. We implement Clipped Double-Q Learning through configuring model.twin_critic.
  - Replay Buffers
    The random-collect-size of DDPG/TD3 is set to 25000 by default, while it is 10000 for SAC. We simply follow the SpinningUp default setting and use a random policy to collect the initialization data, which is configured through random_collect_size. A config sketch covering these keys is shown after this list.
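The extensions above are enabled purely through configuration. The snippet below is a hedged sketch of how the relevant keys might be grouped in a DI-engine-style config dict; the values follow the defaults listed in the table under Implementations, except learn.noise_sigma, which is set to the TD3 paper's 0.2 for illustration.

# Illustrative TD3 policy config fragment (not a complete experiment config).
td3_policy_config = dict(
    cuda=True,
    random_collect_size=25000,                 # random policy warm-up samples
    model=dict(
        twin_critic=True,                      # Clipped Double-Q Learning
    ),
    learn=dict(
        learning_rate_actor=1e-3,
        learning_rate_critic=1e-3,
        actor_update_freq=2,                   # Policy Updates Delay
        target_theta=0.005,                    # soft update of the target network
        noise=True,                            # Target Policy Smoothing
        noise_sigma=0.2,                       # illustrative value from the TD3 paper
        noise_range=dict(min=-0.5, max=0.5),
        ignore_done=False,
    ),
    collect=dict(
        noise_sigma=0.1,                       # exploration noise during collection
    ),
)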
Implementations¶
The default config is defined as follows:
- class ding.policy.td3.TD3Policy(cfg: dict, model: Optional[Union[type, torch.nn.modules.module.Module]] = None, enable_field: Optional[List[str]] = None)[source]¶
- Overview:
Policy class of TD3 algorithm.
Since DDPG and TD3 share many things in common, we can easily derive this TD3 class from the DDPG class by changing _actor_update_freq, _twin_critic and the noise in the model wrapper.
- Property:
learn_mode, collect_mode, eval_mode
Config:
| ID | Symbol | Type | Default Value | Description | Other (Shape) |
| -- | ------ | ---- | ------------- | ----------- | ------------- |
| 1 | type | str | td3 | RL policy register name, refer to registry POLICY_REGISTRY | This arg is optional, a placeholder |
| 2 | cuda | bool | True | Whether to use cuda for network | |
| 3 | random_collect_size | int | 25000 | Number of randomly collected training samples in replay buffer when training starts. | Default to 25000 for DDPG/TD3, 10000 for SAC. |
| 4 | model.twin_critic | bool | True | Whether to use two critic networks or only one. | Default True for TD3; Clipped Double Q-learning method in TD3 paper. |
| 5 | learn.learning_rate_actor | float | 1e-3 | Learning rate for actor network (aka. policy). | |
| 6 | learn.learning_rate_critic | float | 1e-3 | Learning rate for critic network (aka. Q-network). | |
| 7 | learn.actor_update_freq | int | 2 | When critic network updates once, how many times will actor network update. | Default 2 for TD3, 1 for DDPG; Delayed Policy Updates method in TD3 paper. |
| 8 | learn.noise | bool | True | Whether to add noise on target network's action. | Default True for TD3, False for DDPG; Target Policy Smoothing Regularization in TD3 paper. |
| 9 | learn.noise_range | dict | dict(min=-0.5, max=0.5) | Limit for range of target policy smoothing noise, aka. noise_clip. | |
| 10 | learn.ignore_done | bool | False | Determine whether to ignore done flag. | Use ignore_done only in halfcheetah env. |
| 11 | learn.target_theta | float | 0.005 | Used for soft update of the target network. | aka. interpolation factor in polyak averaging for target networks. |
| 12 | collect.noise_sigma | float | 0.1 | Used to add noise during collection, through controlling the sigma of the distribution. | Sample noise from distribution: Ornstein-Uhlenbeck process in DDPG paper, Gaussian process in ours. |
Model¶
Here we provide an example of the td3 model used as the default model for TD3.
- class ding.model.template.qac.QAC(obs_shape: Union[int, ding.utils.type_helper.SequenceType], action_shape: Union[int, ding.utils.type_helper.SequenceType, easydict.EasyDict], actor_head_type: str, twin_critic: bool = False, actor_head_hidden_size: int = 64, actor_head_layer_num: int = 1, critic_head_hidden_size: int = 64, critic_head_layer_num: int = 1, activation: Optional[torch.nn.modules.module.Module] = ReLU(), norm_type: Optional[str] = None)[source]
- Overview:
The QAC model.
- Interfaces:
__init__, forward, compute_actor, compute_critic
- __init__(obs_shape: Union[int, ding.utils.type_helper.SequenceType], action_shape: Union[int, ding.utils.type_helper.SequenceType, easydict.EasyDict], actor_head_type: str, twin_critic: bool = False, actor_head_hidden_size: int = 64, actor_head_layer_num: int = 1, critic_head_hidden_size: int = 64, critic_head_layer_num: int = 1, activation: Optional[torch.nn.modules.module.Module] = ReLU(), norm_type: Optional[str] = None) None [source]
- Overview:
Init the QAC Model according to arguments.
- Arguments:
  - obs_shape (Union[int, SequenceType]): Observation's space.
  - action_shape (Union[int, SequenceType, EasyDict]): Action's space, such as 4, (3, ), EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
  - actor_head_type (str): Whether to choose regression, reparameterization or hybrid.
  - twin_critic (bool): Whether to include a twin critic.
  - actor_head_hidden_size (Optional[int]): The hidden_size to pass to the actor-nn's Head.
  - actor_head_layer_num (int): The number of layers used in the network to compute Q value output for the actor's nn.
  - critic_head_hidden_size (Optional[int]): The hidden_size to pass to the critic-nn's Head.
  - critic_head_layer_num (int): The number of layers used in the network to compute Q value output for the critic's nn.
  - activation (Optional[nn.Module]): The type of activation function to use in MLP after layer_fn; if None, defaults to nn.ReLU().
  - norm_type (Optional[str]): The type of normalization to use; see ding.torch_utils.fc_block for more details.
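For reference, a small usage sketch of constructing this model with TD3-style settings (the shapes below are illustrative, not defaults):

import torch
from ding.model.template.qac import QAC

# Twin critics match TD3's Clipped Double-Q Learning; 'regression' gives a deterministic actor.
model = QAC(obs_shape=11, action_shape=3, actor_head_type='regression', twin_critic=True)
obs = torch.randn(4, 11)
actor_outputs = model(obs, mode='compute_actor')
assert actor_outputs['action'].shape == torch.Size([4, 3])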
- compute_actor(inputs: torch.Tensor) Dict [source]
- Overview:
  Use encoded embedding tensor to predict output. Execute parameter updates with 'compute_actor' mode.
- Arguments:
  - inputs (torch.Tensor): The encoded embedding tensor, determined with given hidden_size, i.e. (B, N=hidden_size), where hidden_size = actor_head_hidden_size.
  - mode (str): Name of the forward mode.
- Returns:
  - outputs (Dict): Outputs of forward pass encoder and head.
- ReturnsKeys (either):
  - action (torch.Tensor): Continuous action tensor with same size as action_shape.
  - logit (torch.Tensor): Logit tensor encoding mu and sigma, both with same size as input x.
  - logit + action_args
- Shapes:
  - inputs (torch.Tensor): \((B, N0)\), B is batch size and N0 corresponds to hidden_size.
  - action (torch.Tensor): \((B, N0)\)
  - logit (Union[list, torch.Tensor]):
    - case 1 (continuous space, list): 2 elements, mu and sigma, each of shape \((B, N0)\).
    - case 2 (hybrid space, torch.Tensor): \((B, N1)\), where N1 is action_type_shape.
  - q_value (torch.FloatTensor): \((B, )\), B is batch size.
  - action_args (torch.FloatTensor): \((B, N2)\), where N2 is action_args_shape (action_args are continuous real values).
- Examples:
>>> # Regression mode
>>> model = QAC(64, 64, 'regression')
>>> inputs = torch.randn(4, 64)
>>> actor_outputs = model(inputs,'compute_actor')
>>> assert actor_outputs['action'].shape == torch.Size([4, 64])
>>> # Reparameterization Mode
>>> model = QAC(64, 64, 'reparameterization')
>>> inputs = torch.randn(4, 64)
>>> actor_outputs = model(inputs,'compute_actor')
>>> actor_outputs['logit'][0].shape  # mu
>>> torch.Size([4, 64])
>>> actor_outputs['logit'][1].shape  # sigma
>>> torch.Size([4, 64])
- compute_critic(inputs: Dict) Dict [source]
- Overview:
  Execute parameter updates with 'compute_critic' mode. Use encoded embedding tensor to predict output.
- Arguments:
  - inputs (Dict): obs, action and logit tensors.
  - mode (str): Name of the forward mode.
- Returns:
  - outputs (Dict): Q-value output.
- ArgumentsKeys:
  - necessary:
    - obs (torch.Tensor): 2-dim vector observation.
    - action (Union[torch.Tensor, Dict]): action from actor.
  - optional:
    - logit (torch.Tensor): discrete action logit.
- ReturnKeys:
  - q_value (torch.Tensor): Q value tensor with same size as batch size.
- Shapes:
  - obs (torch.Tensor): \((B, N1)\), where B is batch size and N1 is obs_shape.
  - action (torch.Tensor): \((B, N2)\), where B is batch size and N2 is action_shape.
  - q_value (torch.FloatTensor): \((B, )\), where B is batch size.
- Examples:
>>> inputs = {'obs': torch.randn(4, N), 'action': torch.randn(4, 1)}
>>> model = QAC(obs_shape=(N, ),action_shape=1,actor_head_type='regression')
>>> model(inputs, mode='compute_critic')['q_value']  # q value
>>> tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<SqueezeBackward1>)
- forward(inputs: Union[torch.Tensor, Dict], mode: str) Dict [source]
- Overview:
  Use observation and action tensor to predict output. Parameter updates with QAC's MLPs forward setup.
- Arguments:
  - Forward with 'compute_actor':
    - inputs (torch.Tensor): The encoded embedding tensor, determined with given hidden_size, i.e. (B, N=hidden_size). Whether actor_head_hidden_size or critic_head_hidden_size is used depends on mode.
  - Forward with 'compute_critic', inputs (Dict), necessary keys:
    - obs, action encoded tensors.
  - mode (str): Name of the forward mode.
- Returns:
  - outputs (Dict): Outputs of network forward.
    - Forward with 'compute_actor', necessary keys (either):
      - action (torch.Tensor): Action tensor with same size as input x.
      - logit (torch.Tensor): Logit tensor encoding mu and sigma, both with same size as input x.
    - Forward with 'compute_critic', necessary keys:
      - q_value (torch.Tensor): Q value tensor with same size as batch size.
- Actor Shapes:
  - inputs (torch.Tensor): \((B, N0)\), B is batch size and N0 corresponds to hidden_size.
  - action (torch.Tensor): \((B, N0)\)
  - q_value (torch.FloatTensor): \((B, )\), where B is batch size.
- Critic Shapes:
  - obs (torch.Tensor): \((B, N1)\), where B is batch size and N1 is obs_shape.
  - action (torch.Tensor): \((B, N2)\), where B is batch size and N2 is action_shape.
  - logit (torch.FloatTensor): \((B, N2)\), where B is batch size and N2 is action_shape.
- Actor Examples:
>>> # Regression mode
>>> model = QAC(64, 64, 'regression')
>>> inputs = torch.randn(4, 64)
>>> actor_outputs = model(inputs,'compute_actor')
>>> assert actor_outputs['action'].shape == torch.Size([4, 64])
>>> # Reparameterization Mode
>>> model = QAC(64, 64, 'reparameterization')
>>> inputs = torch.randn(4, 64)
>>> actor_outputs = model(inputs,'compute_actor')
>>> actor_outputs['logit'][0].shape  # mu
>>> torch.Size([4, 64])
>>> actor_outputs['logit'][1].shape  # sigma
>>> torch.Size([4, 64])
- Critic Examples:
>>> inputs = {'obs': torch.randn(4, N), 'action': torch.randn(4, 1)}
>>> model = QAC(obs_shape=(N, ),action_shape=1,actor_head_type='regression')
>>> model(inputs, mode='compute_critic')['q_value']  # q value
>>> tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<SqueezeBackward1>)
Train actor-critic model¶
First, we initialize the actor and critic optimizers in _init_learn, respectively.
Setting up two separate optimizers guarantees that we only update the actor network parameters (and not the critic network) when we compute the actor loss, and vice versa.
# actor and critic optimizer
self._optimizer_actor = Adam(
    self._model.actor.parameters(),
    lr=self._cfg.learn.learning_rate_actor,
    weight_decay=self._cfg.learn.weight_decay
)
self._optimizer_critic = Adam(
    self._model.critic.parameters(),
    lr=self._cfg.learn.learning_rate_critic,
    weight_decay=self._cfg.learn.weight_decay
)
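The following toy sketch (plain nn.Linear stand-ins, not the DI-engine classes) illustrates the point above: the actor loss backpropagates through the critic, yet only the actor optimizer steps, so the critic's weights are left untouched.

import torch
import torch.nn as nn
from torch.optim import Adam

actor = nn.Linear(4, 2)              # toy actor
critic = nn.Linear(4 + 2, 1)         # toy critic over (obs, action)
optimizer_actor = Adam(actor.parameters(), lr=1e-3)
optimizer_critic = Adam(critic.parameters(), lr=1e-3)

obs = torch.randn(8, 4)
critic_weight_before = critic.weight.clone()

# Actor loss: maximize Q(s, pi(s)); gradients flow through the critic as well...
actor_loss = -critic(torch.cat([obs, actor(obs)], dim=-1)).mean()
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()               # ...but only the actor's parameters are updated.

assert torch.equal(critic.weight, critic_weight_before)  # critic unchanged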
- In _forward_learn, we update the actor-critic policy by computing the critic loss, updating the critic network, computing the actor loss, and updating the actor network.

critic loss computation

current and target value computation
# current q value
q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
q_value_dict = {}
if self._twin_critic:
    q_value_dict['q_value'] = q_value[0].mean()
    q_value_dict['q_value_twin'] = q_value[1].mean()
else:
    q_value_dict['q_value'] = q_value.mean()
# target q value. SARSA: first predict next action, then calculate next q value
with torch.no_grad():
    next_action = self._target_model.forward(next_obs, mode='compute_actor')['action']
    next_data = {'obs': next_obs, 'action': next_action}
    target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
target (Clipped Double-Q Learning) and loss computation
if self._twin_critic:
    # TD3: two critic networks
    target_q_value = torch.min(target_q_value[0], target_q_value[1])  # find min one as target q value
    # network1
    td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
    critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
    loss_dict['critic_loss'] = critic_loss
    # network2(twin network)
    td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
    critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
    loss_dict['critic_twin_loss'] = critic_twin_loss
    td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
else:
    # DDPG: single critic network
    td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
    critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
    loss_dict['critic_loss'] = critic_loss
critic network update
self._optimizer_critic.zero_grad()
for k in loss_dict:
    if 'critic' in k:
        loss_dict[k].backward()
self._optimizer_critic.step()
actor loss and actor network update, depending on the policy update delay.
if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
    actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
    actor_data['obs'] = data['obs']
    if self._twin_critic:
        actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0].mean()
    else:
        actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'].mean()
    loss_dict['actor_loss'] = actor_loss
    # actor update
    self._optimizer_actor.zero_grad()
    actor_loss.backward()
    self._optimizer_actor.step()
Target Network¶
We implement the Target Network through target model initialization in _init_learn.
We configure learn.target_theta to control the interpolation factor in the polyak averaging.
# main and target models
self._target_model = copy.deepcopy(self._model)
self._target_model = model_wrap(
self._target_model,
wrapper_name='target',
update_type='momentum',
update_kwargs={'theta': self._cfg.learn.target_theta}
)
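As a point of reference, the momentum-style soft update driven by learn.target_theta amounts to polyak averaging. The sketch below (plain tensors standing in for network parameters) shows the update rule under that assumption:

import torch

def soft_update(target_params, main_params, theta=0.005):
    # Polyak averaging: target <- (1 - theta) * target + theta * main
    with torch.no_grad():
        for t, m in zip(target_params, main_params):
            t.mul_(1.0 - theta).add_(theta * m)

target = [torch.zeros(3)]
main = [torch.ones(3)]
soft_update(target, main, theta=0.005)
print(target[0])  # tensor([0.0050, 0.0050, 0.0050])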
Target Policy Smoothing Regularization¶
We implement Target Policy Smoothing Regularization through target model initialization in _init_learn.
We configure learn.noise, learn.noise_sigma, and learn.noise_range to control the added noise, which is clipped to keep the target close to the original action.
if self._cfg.learn.noise:
self._target_model = model_wrap(
self._target_model,
wrapper_name='action_noise',
noise_type='gauss',
noise_kwargs={
'mu': 0.0,
'sigma': self._cfg.learn.noise_sigma
},
noise_range=self._cfg.learn.noise_range
)
The benchmark result of TD3 implemented in DI-engine is shown in the Benchmark section.
Other Public Implementations¶
References¶
Scott Fujimoto, Herke van Hoof, David Meger: "Addressing Function Approximation Error in Actor-Critic Methods", 2018; arXiv:1802.09477. http://arxiv.org/abs/1802.09477