ACER

Overview

ACER, short for Actor-Critic with Experience Replay, is an off-policy actor-critic algorithm that greatly improves sample efficiency and decreases data correlation. ACER uses Retrace Q-value estimation, an efficient trust region policy optimization (TRPO) method, and truncated importance sampling weights with bias correction to control the stability of the off-policy estimator. You can find more details in the paper Sample Efficient Actor-Critic with Experience Replay.

Quick Facts

  1. ACER is a model-free and off-policy RL algorithm.

  2. ACER supports both discrete and continuous action spaces, with several differences between the two cases.

  3. ACER is an actor-critic RL algorithm, which optimizes the actor network and the critic network separately.

  4. ACER decouples acting from learning. Collectors in ACER need to record the behavior probability distributions.

Key Equations

The loss used in ACER consists of a policy loss and a value loss. They are usually updated separately, so it is necessary to control their relative update speeds.
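For illustration only, a minimal sketch of controlling the relative update speed with a weighting coefficient (the network, losses, and coefficient below are hypothetical stand-ins, not DI-engine's actual training loop):

import torch
import torch.nn as nn

net = nn.Linear(4, 3)                       # stand-in for a shared actor-critic trunk
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

obs = torch.randn(8, 4)
out = net(obs)
policy_loss = out[:, :2].mean()             # stand-in for the actor loss
value_loss = out[:, 2].pow(2).mean()        # stand-in for the critic (MSE) loss
value_weight = 0.5                          # controls the critic's update speed relative to the actor's

optimizer.zero_grad()
policy_loss.backward(retain_graph=True)     # both losses share the same forward graph
(value_weight * value_loss).backward()
optimizer.step()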

Retrace Q-value estimation

Given a trajectory \(\{x_0, a_0, r_0, \mu(\cdot|x_0), \ldots, x_k, a_k, r_k, \mu(\cdot|x_k)\}\) generated under the behavior policy \(\mu\), the Retrace estimator can be expressed recursively as follows:

\[Q^{\text{ret}}(x_t,a_t)=r_t+\gamma\bar{\rho}_{t+1}[Q^{\text{ret}}(x_{t+1},a_{t+1})-Q(x_{t+1},a_{t+1})]+\gamma V(x_{t+1})\]

where \(\bar{\rho}_t\) is the truncated importance weight, \(\bar{\rho}_t=\min\{c,\rho_t\}\) with \(\rho_t=\frac{\pi(a_t|x_t)}{\mu(a_t|x_t)}\), and \(\pi\) is the target policy. Retrace is an off-policy, return-based algorithm with low variance, and it is proven to converge to the value function of the target policy for any behavior policy. We approximate the Q value with a neural network \(Q_{\theta}\) and train it with a mean squared error loss:

\[L_{\text{value}}=\frac{1}{2}(Q^{\text{ret}}(x_t,a_t)-Q_{\theta}(x_t,a_t))^2.\]
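Unrolling the recursion above (a standard rewriting, shown here only for intuition) gives an equivalent non-recursive form that makes the role of the truncated importance weights explicit:

\[Q^{\text{ret}}(x_t,a_t)=Q(x_t,a_t)+\sum_{s\ge t}\gamma^{s-t}\left(\prod_{i=t+1}^{s}\bar{\rho}_i\right)\left[r_s+\gamma V(x_{s+1})-Q(x_s,a_s)\right]\]

Each term is a TD error propagated backwards in time and scaled by a product of truncated importance weights, which keeps the variance of the estimator bounded.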

Policy gradient

To safeguard against high variance, ACER uses truncated importance weights and introduces a correction term via the following decomposition of \(g^{\text{acer}}\):

\[g^{\text{acer}}=\bar{\rho}_t\nabla_\theta\log\pi_{\theta}(a_t|x_t)[Q^{\text{ret}}(x_t,a_t)-V_{\theta}(x_t)]+\mathbb{E}_{a\sim \pi}\left(\left[\frac{\rho_t(a)-c}{\rho_t(a)}\right]_+\nabla_{\theta}\log\pi_{\theta}(a|x_t)[Q_\theta(x_t,a)-V_{\theta}(x_t)]\right).\]
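This decomposition relies on the truncation-with-bias-correction trick from the paper: for any function \(f\),

\[\mathbb{E}_{a_t\sim\mu}\left[\rho_t f(a_t)\right]=\mathbb{E}_{a_t\sim\mu}\left[\bar{\rho}_t f(a_t)\right]+\mathbb{E}_{a\sim\pi}\left[\left[\frac{\rho_t(a)-c}{\rho_t(a)}\right]_+ f(a)\right].\]

The first (truncated) term has bounded variance since \(\bar{\rho}_t\le c\), while the second (correction) term is only active for actions with \(\rho_t(a)>c\) and removes the bias introduced by the truncation.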

To ensure stability, ACER limits the per-step change to the policy by solving the following optimization problem with a linearized KL divergence constraint:

\[\begin{split}&\text{minimize}_z\quad\frac{1}{2}\|g_t^{\text{acer}}-z\|_2^2\\ &\text{subject to}\quad \nabla_{\phi_{\theta}(x_t)}D_{KL}[f(\cdot|\phi_{\theta_a}(x_t))\|f(\cdot|\phi_{\theta}(x_t))]^\top z\le\delta\end{split}\]

Here \(\phi_{\theta}\) denotes the target policy network and \(\phi_{\theta_a}\) denotes the average policy network. By letting \(k=\nabla_{\phi_{\theta}(x_t)}D_{KL}[f(\cdot|\phi_{\theta_a}(x_t))\|f(\cdot|\phi_{\theta}(x_t))]\), the solution can easily be derived in closed form using the KKT conditions:

\[z^*=g_{t}^{\text{acer}}-\max\{0,\frac{k^\top g_t^{\text{acer}}-\delta}{\|k\|_2^2}\}k\]
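For illustration, here is a minimal sketch of this closed-form projection for a single gradient vector (the names g, k, and delta are ours, not DI-engine's):

import torch

def trust_region_project(g, k, delta):
    # z* = g - max(0, (k^T g - delta) / ||k||^2) * k, for 1-D tensors g and k
    scale = torch.clamp((k @ g - delta) / (k @ k), min=0.0)
    return g - scale * k

# hypothetical usage with random vectors
g = torch.randn(6)  # stand-in for g_t^acer
k = torch.randn(6)  # stand-in for the KL gradient
z = trust_region_project(g, k, delta=1.0)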

Pseudocode

There are a few differences between ACER applied to discrete action spaces and ACER applied to continuous action spaces.

[Figure: ACER pseudocode (../_images/ACER_alg1.png, ../_images/ACER_alg2.png)]

In a continuous action space, it is impossible to enumerate the Q values of all actions, so ACER uses sampled actions to approximate the expectation.

[Figure: ACER pseudocode for the continuous action space (../_images/ACER_alg3.png)]
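For intuition, a minimal sketch of this sampling trick, assuming a Gaussian policy (all names and tensors below are illustrative stand-ins, not DI-engine's implementation):

import torch

# Monte Carlo estimate of V(x) = E_{a ~ pi}[Q(x, a)] when actions are continuous
# and cannot be enumerated.
mean, std = torch.zeros(2), torch.ones(2)   # stand-in Gaussian policy pi(.|x)
q_fn = lambda a: -(a ** 2).sum(-1)          # stand-in critic Q(x, a)

n_samples = 8
actions = torch.distributions.Normal(mean, std).sample((n_samples,))
v_estimate = q_fn(actions).mean()           # sample average replaces the expectation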

Implementations

Here we show the ACER algorithm for the discrete action space. The default config is defined as follows:

class ding.policy.acer.ACERPolicy(cfg: dict, model: Optional[Union[type, torch.nn.modules.module.Module]] = None, enable_field: Optional[List[str]] = None)
Overview:

Policy class of ACER algorithm.

Config:

ID | Symbol | Type | Default Value | Description | Other (Shape)
---|--------|------|---------------|-------------|--------------
1 | type | str | acer | RL policy register name, refer to registry POLICY_REGISTRY | this arg is optional, a placeholder
2 | cuda | bool | False | Whether to use cuda for the network | this arg can differ between modes
3 | on_policy | bool | False | Whether the RL algorithm is on-policy or off-policy |
4 | trust_region | bool | True | Whether the RL algorithm uses a trust region constraint |
5 | trust_region_value | float | 1.0 | Maximum range of the trust region |
6 | unroll_len | int | 32 | Trajectory length used to calculate the Q retrace target |
7 | learn.update_per_collect | int | 4 | How many updates (iterations) to train after one collection by the collector; only valid in serial training | this arg can vary between envs; a bigger value means more off-policy
8 | c_clip_ratio | float | 1.0 | Clip ratio of importance weights |
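For illustration, overriding these fields in a config dict might look like the sketch below (the exact nesting follows DI-engine's config conventions and may differ):

# Illustrative config override for ACERPolicy; keys mirror the table above.
acer_config = dict(
    type='acer',
    cuda=False,
    on_policy=False,
    trust_region=True,
    trust_region_value=1.0,
    unroll_len=32,
    c_clip_ratio=1.0,
    learn=dict(
        update_per_collect=4,
    ),
)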

Usually, we want to compute everything as a batch to improve efficiency. This is done in policy._get_train_sample. Once we execute this function in the collector, the length of the samples will equal unroll_len in the config. For details, please refer to the documentation of ding.rl_utils.adder.

You can find more information here.

You can find the whole code of ACER here. Below we show some details of the algorithm.

First, we use the following function to compute the Retrace Q value.

import torch

def compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio, gamma=0.9):
    """
    Overview:
        Compute the Retrace Q value.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q values
        - v_pred (:obj:`torch.Tensor`): V values
        - rewards (:obj:`torch.Tensor`): reward values
        - actions (:obj:`torch.Tensor`): the actions in the replay buffer
        - weights (:obj:`torch.Tensor`): mask marking padding positions
        - ratio (:obj:`torch.Tensor`): ratio of the new policy to the behavior policy
        - gamma (:obj:`float`): discount factor
    Returns:
        - q_retraces (:obj:`torch.Tensor`): Retrace Q values
    """
    rewards = rewards.unsqueeze(-1)
    actions = actions.unsqueeze(-1)
    weights = weights.unsqueeze(-1)
    q_retraces = torch.zeros_like(v_pred)
    n_len = q_retraces.size()[0]
    # bootstrap from the value estimate of the last state
    tmp_retraces = v_pred[-1, ...]
    q_retraces[-1, ...] = v_pred[-1, ...]
    q_gather = torch.zeros_like(v_pred)
    q_gather[0:-1, ...] = q_values[0:-1, ...].gather(-1, actions)
    ratio_gather = ratio.gather(-1, actions)
    # recurse backwards through the trajectory:
    # Q_ret(x_t) = r_t + gamma * [rho_bar_{t+1} * (Q_ret(x_{t+1}) - Q(x_{t+1})) + V(x_{t+1})]
    for idx in reversed(range(n_len - 1)):
        q_retraces[idx, ...] = rewards[idx, ...] + gamma * weights[idx, ...] * tmp_retraces
        tmp_retraces = ratio_gather[idx, ...].clamp(max=1.0) * \
            (q_retraces[idx, ...] - q_gather[idx, ...]) + v_pred[idx, ...]
    return q_retraces

After that, we calculate the policy loss. The following function computes the actor loss with truncated importance weights, together with the bias-correction loss:

def acer_policy_error(q_values, q_retraces, v_pred, target_pi, actions, ratio, c_clip_ratio=10.0):
    """
    Overview:
        Compute the ACER policy loss.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q values
        - q_retraces (:obj:`torch.Tensor`): Q values computed by the Retrace method
        - v_pred (:obj:`torch.Tensor`): V values
        - target_pi (:obj:`torch.Tensor`): the new policy's probabilities
        - actions (:obj:`torch.Tensor`): the actions in the replay buffer
        - ratio (:obj:`torch.Tensor`): ratio of the new policy to the behavior policy
        - c_clip_ratio (:obj:`float`): clip value for the ratio
    Returns:
        - actor_loss (:obj:`torch.Tensor`): policy loss from q_retrace
        - bc_loss (:obj:`torch.Tensor`): bias-correction policy loss
    """
    actions = actions.unsqueeze(-1)
    with torch.no_grad():
        advantage_retraces = q_retraces - v_pred
        advantage_native = q_values - v_pred
    # truncated importance-sampling term (EPS is a small constant for numerical stability,
    # defined in the original module)
    actor_loss = ratio.gather(-1, actions).clamp(max=c_clip_ratio) * advantage_retraces * \
        (target_pi.gather(-1, actions) + EPS).log()
    # bias-correction term: [(rho - c) / rho]_+ weighted over actions sampled from pi
    bc_loss = (1.0 - c_clip_ratio / (ratio + EPS)).clamp(min=0.0) * target_pi.detach() * \
        advantage_native * (target_pi + EPS).log()
    bc_loss = bc_loss.sum(-1).unsqueeze(-1)
    return actor_loss, bc_loss

Then we backpropagate the policy loss to target_pi. Moreover, we need to calculate the corrected gradients under the trust region constraint:

def acer_trust_region_update(actor_gradients, target_pi, avg_pi, trust_region_value):
    """
    Overview:
        Calculate the gradients under the trust region constraint.
    Arguments:
        - actor_gradients (:obj:`list(torch.Tensor)`): gradients of the actor loss w.r.t. the policy output
        - target_pi (:obj:`torch.Tensor`): the new policy's probabilities
        - avg_pi (:obj:`torch.Tensor`): the average policy's probabilities
        - trust_region_value (:obj:`float`): the range of the trust region
    Returns:
        - update_gradients (:obj:`list(torch.Tensor)`): gradients under the trust region constraint
    """
    with torch.no_grad():
        # gradient (up to sign) of KL(f_avg || f) w.r.t. the policy probabilities
        KL_gradients = [(avg_pi / (target_pi + EPS))]
    update_gradients = []
    for actor_gradient, KL_gradient in zip(actor_gradients, KL_gradients):
        # z* = g - max(0, (k^T g - delta) / ||k||^2) * k
        scale = actor_gradient.mul(KL_gradient).sum(-1).unsqueeze(-1) - trust_region_value
        scale = torch.div(scale, KL_gradient.mul(KL_gradient).sum(-1).unsqueeze(-1)).clamp(min=0.0)
        update_gradients.append(actor_gradient - scale * KL_gradient)
    return update_gradients

With the new gradients, we can continue backpropagation and then update the parameters accordingly.
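As a rough sketch of this two-stage backward pass (illustrative only; it uses standard PyTorch autograd calls and the acer_trust_region_update function defined above, with all tensors as random stand-ins):

import torch

EPS = 1e-8  # small constant; defined once in the original DI-engine module

# random stand-ins for the policy outputs
logits = torch.randn(4, 3, requires_grad=True)
target_pi = torch.softmax(logits, dim=-1)
avg_pi = torch.softmax(torch.randn(4, 3), dim=-1)

# stand-in policy loss built from target_pi
actor_loss = -(target_pi[:, 0] + EPS).log().mean()

# step 1: gradients of the loss w.r.t. the policy probabilities
actor_grads = torch.autograd.grad(actor_loss, target_pi, retain_graph=True)

# step 2: correct them with the trust region projection defined above
update_grads = acer_trust_region_update(list(actor_grads), target_pi, avg_pi, 1.0)

# step 3: continue backpropagation with the corrected gradients
target_pi.backward(gradient=update_grads[0])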

Finally, we calculate the mean squared error loss for the Q values to update the Q-network:

def acer_value_error(q_values, q_retraces, actions):
    """
    Overview:
        Compute the ACER critic loss.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q values
        - q_retraces (:obj:`torch.Tensor`): Q values computed by the Retrace method
        - actions (:obj:`torch.Tensor`): the actions in the replay buffer
    Returns:
        - critic_loss (:obj:`torch.Tensor`): critic loss
    """
    actions = actions.unsqueeze(-1)
    # mean squared error between the Retrace target and the predicted Q value of the taken action
    critic_loss = 0.5 * (q_retraces - q_values.gather(-1, actions)).pow(2)
    return critic_loss

Reference

Ziyu Wang, Victor Bapst, Nicolas Heess, Volodymyr Mnih, Rémi Munos, Koray Kavukcuoglu, Nando de Freitas: "Sample Efficient Actor-Critic with Experience Replay", 2016; arXiv:1611.01224, https://arxiv.org/abs/1611.01224.