SAC

Overview

Soft actor-critic (SAC) is an off-policy maximum entropy actor-critic algorithm, which was proposed in the 2018 paper Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.

Quick Facts

  1. SAC is implemented for environments with continuous action spaces (e.g. MuJoCo, Pendulum, and LunarLander).

  2. SAC is an off-policy, model-free algorithm. It learns from a replay buffer, which is filled with randomly collected samples (random_collect_size) at the start of training to aid exploration.

  3. SAC is an actor-critic RL algorithm that optimizes the actor network and the critic network separately.

  4. SAC supports multi-dimensional continuous action spaces.

Key Equations or Key Graphs

SAC considers a more general maximum entropy objective, which favors stochastic policies by augmenting the objective with the expected entropy of the policy:

\[J(\pi)=\sum_{t=0}^{T} \mathbb{E}_{\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right) \sim \rho_{\pi}}\left[r\left(\mathbf{s}_{t}, \mathbf{a}_{t}\right)+\alpha \mathcal{H}\left(\pi\left(\cdot \mid \mathbf{s}_{t}\right)\right)\right].\]
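To make the objective concrete, the sketch below evaluates the bracketed term of \(J(\pi)\) along one sampled trajectory, where \(\alpha\) is the temperature weighting the entropy bonus. It assumes a diagonal Gaussian policy with known per-step standard deviations, so the entropy has the closed form \(\sum_d \tfrac{1}{2}\log(2\pi e\sigma_d^2)\). The function names and inputs are hypothetical, and the result is a single-sample estimate rather than DI-engine code.

```python
import math
from typing import List, Sequence


def gaussian_entropy(stds: Sequence[float]) -> float:
    """Entropy of a diagonal Gaussian: sum over dims of 0.5 * log(2*pi*e*sigma^2)."""
    return sum(0.5 * math.log(2.0 * math.pi * math.e * s * s) for s in stds)


def soft_return(rewards: List[float], policy_stds: List[Sequence[float]], alpha: float) -> float:
    """Single-trajectory estimate of sum_t [r_t + alpha * H(pi(.|s_t))] (undiscounted, as in J(pi))."""
    assert len(rewards) == len(policy_stds)
    return sum(r + alpha * gaussian_entropy(stds) for r, stds in zip(rewards, policy_stds))


# Two steps of a 1-D policy: the wider policy (larger std) earns a larger entropy bonus.
rewards = [1.0, 0.5]
stds = [[1.0], [0.5]]
print(soft_return(rewards, stds, alpha=0.2))
print(soft_return(rewards, stds, alpha=0.0))  # alpha = 0 recovers the plain sum of rewards
```

Setting \(\alpha = 0\) recovers the standard expected-return objective, which is why \(\alpha\) is the knob that trades reward against stochasticity.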

The temperature parameter \(\alpha > 0\) determines the relative importance of the entropy term against the reward, and thus controls the stochasticity of the optimal policy. The SAC paper considers a parameterized state value function, a soft Q-function, and a tractable policy. Specifically, the value functions are modeled as expressive neural networks, and the policy as a Gaussian with mean and covariance given by neural networks. To minimize the expected KL divergence between the policy and the exponentiated soft Q-function, SAC applies the reparameterization trick, writing the action as \(\mathbf{a}_{t}=f_{\phi}\left(\epsilon_{t} ; \mathbf{s}_{t}\right)\) with input noise \(\epsilon_{t} \sim \mathcal{N}\), which yields the policy objective

\[J_{\pi}(\phi)=\mathbb{E}_{\mathbf{s}_{t} \sim \mathcal{D}, \epsilon_{t} \sim \mathcal{N}}\left[\log \pi_{\phi}\left(f_{\phi}\left(\epsilon_{t} ; \mathbf{s}_{t}\right) \mid \mathbf{s}_{t}\right)-Q_{\theta}\left(\mathbf{s}_{t}, f_{\phi}\left(\epsilon_{t} ; \mathbf{s}_{t}\right)\right)\right]\]
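The benefit of writing the action as a deterministic function of noise, \(a = \mu + \sigma\epsilon\) for a Gaussian policy, is that gradients flow through the sample. The sketch below assumes a hypothetical 1-D critic \(Q(a) = -(a-2)^2\) that ignores the state, estimates \(\partial_\mu \mathbb{E}[Q(a)]\) by averaging the pathwise derivative \(\frac{dQ}{da}\frac{\partial a}{\partial \mu}\) over noise samples, and compares it with the exact value \(-2(\mu - 2)\). In SAC the same chain rule runs through both \(\log\pi\) and \(Q\); only the \(Q\) part is shown here.

```python
import random


def toy_q(a: float) -> float:
    """Hypothetical critic that peaks at a = 2."""
    return -(a - 2.0) ** 2


def toy_q_grad(a: float) -> float:
    """dQ/da for the toy critic."""
    return -2.0 * (a - 2.0)


def reparam_grad_mu(mu: float, sigma: float, n: int = 100_000, seed: int = 0) -> float:
    """Monte Carlo estimate of d/dmu E_{eps~N(0,1)}[Q(mu + sigma * eps)] via the reparameterization trick."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)
        a = mu + sigma * eps      # a = f(eps; s): differentiable in mu and sigma
        total += toy_q_grad(a)    # chain rule: dQ/dmu = dQ/da * da/dmu, and da/dmu = 1
    return total / n


mu, sigma = 0.5, 0.3
estimate = reparam_grad_mu(mu, sigma)
exact = -2.0 * (mu - 2.0)  # E[Q] = -(mu - 2)^2 - sigma^2, so dE[Q]/dmu = -2 (mu - 2)
print(estimate, exact)
```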

Note

Compared with the vanilla version, which models both a state value function and a soft Q-function, our implementation provides two versions: one models both the state value function and the soft Q-function, while the other models only the soft Q-function, using a double (twin) Q network.

Pseudocode

\begin{algorithm}[H]
\caption{Soft Actor-Critic}
\label{alg1}
\begin{algorithmic}[1]
\STATE Input: initial policy parameters $\theta$, Q-function parameters $\phi_1$, $\phi_2$, empty replay buffer $\mathcal{D}$
\STATE Set target parameters equal to main parameters $\phi_{\text{targ},1} \leftarrow \phi_1$, $\phi_{\text{targ},2} \leftarrow \phi_2$
\REPEAT
  \STATE Observe state $s$ and select action $a \sim \pi_{\theta}(\cdot|s)$
  \STATE Execute $a$ in the environment
  \STATE Observe next state $s'$, reward $r$, and done signal $d$ to indicate whether $s'$ is terminal
  \STATE Store $(s,a,r,s',d)$ in replay buffer $\mathcal{D}$
  \STATE If $s'$ is terminal, reset environment state.
  \IF{it's time to update}
    \FOR{$j$ in range(however many updates)}
      \STATE Randomly sample a batch of transitions, $B = \{ (s,a,r,s',d) \}$ from $\mathcal{D}$
      \STATE Compute targets for the Q functions:
      \begin{align*}
        y (r,s',d) &= r + \gamma (1-d) \left(\min_{i=1,2} Q_{\phi_{\text{targ}, i}} (s', \tilde{a}') - \alpha \log \pi_{\theta}(\tilde{a}'|s')\right), && \tilde{a}' \sim \pi_{\theta}(\cdot|s')
      \end{align*}
      \STATE Update Q-functions by one step of gradient descent using
      \begin{align*}
        & \nabla_{\phi_i} \frac{1}{|B|}\sum_{(s,a,r,s',d) \in B} \left( Q_{\phi_i}(s,a) - y(r,s',d) \right)^2 && \text{for } i=1,2
      \end{align*}
      \STATE Update policy by one step of gradient ascent using
      \begin{equation*}
        \nabla_{\theta} \frac{1}{|B|}\sum_{s \in B} \Big(\min_{i=1,2} Q_{\phi_i}(s, \tilde{a}_{\theta}(s)) - \alpha \log \pi_{\theta} \left(\left. \tilde{a}_{\theta}(s) \right| s\right) \Big),
      \end{equation*}
      where $\tilde{a}_{\theta}(s)$ is a sample from $\pi_{\theta}(\cdot|s)$ which is differentiable with respect to $\theta$ via the reparameterization trick.
      \STATE Update target networks with
      \begin{align*}
        \phi_{\text{targ},i} &\leftarrow \rho \phi_{\text{targ}, i} + (1-\rho) \phi_i && \text{for } i=1,2
      \end{align*}
    \ENDFOR
  \ENDIF
\UNTIL{convergence}
\end{algorithmic}
\end{algorithm}
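The last step of the loop, Polyak averaging of the target parameters, can be sketched on plain lists of floats. Here `rho` is the \(\rho\) of the pseudocode; under the usual convention the config parameter learn.target_theta (0.005 by default) plays the role of \(1-\rho\), the weight given to the online network (an assumed mapping). The function name is hypothetical; a real implementation would apply the same rule to every parameter tensor of the network.

```python
from typing import List


def polyak_update(target: List[float], online: List[float], rho: float) -> List[float]:
    """phi_targ <- rho * phi_targ + (1 - rho) * phi, applied element-wise."""
    assert len(target) == len(online) and 0.0 <= rho <= 1.0
    return [rho * t + (1.0 - rho) * p for t, p in zip(target, online)]


target_theta = 0.005                  # weight on the online network, i.e. 1 - rho (assumed convention)
target_params = [0.0, 1.0]
online_params = [1.0, 1.0]
for _ in range(3):
    target_params = polyak_update(target_params, online_params, rho=1.0 - target_theta)
print(target_params)  # the first entry moves slowly toward 1.0; the second is already equal and stays there
```

A small weight on the online network keeps the bootstrapped targets slowly moving, which stabilizes the Q regression.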

Note

Compared with the vanilla version, our second implementation version optimizes only the Q networks and the actor network; no separate state value network is trained.

Extensions

SAC can be combined with:

Implementation

The default config is defined as follows:

class ding.policy.sac.SACPolicy(cfg: dict, model: Optional[Union[type, torch.nn.modules.module.Module]] = None, enable_field: Optional[List[str]] = None)
Overview:

Policy class of SAC algorithm.

https://arxiv.org/pdf/1801.01290.pdf

Config:

ID  Symbol                       Type   Default Value  Description                              Other (Shape)
--  ---------------------------  -----  -------------  ---------------------------------------  --------------------------------------
1   type                         str    sac            RL policy register name; refer to        This arg is optional, a placeholder.
                                                       registry POLICY_REGISTRY.

2   cuda                         bool   True           Whether to use CUDA for the network.

3   random_collect_size          int    10000          Number of randomly collected training    Defaults to 10000 for SAC, 25000 for
                                                       samples in the replay buffer when        DDPG/TD3.
                                                       training starts.

4   model.policy_embedding_size  int    256            Linear layer size for the policy
                                                       network.

5   model.soft_q_embedding_size  int    256            Linear layer size for the soft Q
                                                       network.

6   model.value_embedding_size   int    256            Linear layer size for the value          Defaults to None when
                                                       network.                                 model.value_network is False.

7   learn.learning_rate_q        float  3e-4           Learning rate for the soft Q             Defaults to 1e-3 when
                                                       network.                                 model.value_network is True.

8   learn.learning_rate_policy   float  3e-4           Learning rate for the policy             Defaults to 1e-3 when
                                                       network.                                 model.value_network is True.

9   learn.learning_rate_value    float  3e-4           Learning rate for the value              Defaults to None when
                                                       network.                                 model.value_network is False.

10  learn.alpha                  float  0.2            Entropy regularization                   Initial value of alpha when
                                                       coefficient.                             auto_alpha is True.

11  learn.reparameterization     bool   True           Whether to use the reparameterization
                                                       trick.

12  learn.auto_alpha             bool   False          Whether to automatically tune            The temperature parameter determines
                                                       the temperature parameter                the relative importance of the entropy
                                                       alpha.                                   term against the reward.

13  learn.ignore_done            bool   False          Whether to ignore the done               Use ignore_done only in the
                                                       flag.                                    halfcheetah env.

14  learn.target_theta           float  0.005          Used for soft update of the target       Interpolation factor in Polyak
                                                       network.                                 averaging for target networks.
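The defaults in the table can be gathered into one nested Python dict, which makes it easy to override single entries. This is a sketch assuming the dotted symbols (model.*, learn.*) map to nested keys; it is not copied from DI-engine's source, and `override` is a hypothetical helper rather than a DI-engine API.

```python
# Defaults from the config table, arranged as nested dicts (key layout assumed).
sac_default_config = dict(
    type='sac',                 # register name of this policy
    cuda=True,
    random_collect_size=10000,
    model=dict(
        policy_embedding_size=256,
        soft_q_embedding_size=256,
        value_embedding_size=256,
    ),
    learn=dict(
        learning_rate_q=3e-4,
        learning_rate_policy=3e-4,
        learning_rate_value=3e-4,
        alpha=0.2,
        reparameterization=True,
        auto_alpha=False,
        ignore_done=False,
        target_theta=0.005,
    ),
)


def override(base: dict, patch: dict) -> dict:
    """Return a copy of `base` with `patch` merged in recursively (hypothetical helper)."""
    merged = dict(base)
    for key, value in patch.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = override(merged[key], value)
        else:
            merged[key] = value
    return merged


# Example: enable automatic temperature tuning while keeping every other default.
cfg = override(sac_default_config, {'learn': {'auto_alpha': True}})
print(cfg['learn']['auto_alpha'], cfg['learn']['alpha'])  # True 0.2
```

Because `override` copies each nested dict it touches, the default config is left unchanged.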

We take the second implementation version (which models only the soft Q-function) as an example to introduce the SAC algorithm:

The SAC model includes a soft Q network and a policy network:

Model initialization:

# build network
self._policy_net = PolicyNet(self._obs_shape, self._act_shape, self._policy_embedding_size)

self._twin_q = twin_q
if not self._twin_q:
    self._soft_q_net = SoftQNet(self._obs_shape, self._act_shape, self._soft_q_embedding_size)
else:
    self._soft_q_net = nn.ModuleList()
    for i in range(2):
        self._soft_q_net.append(SoftQNet(self._obs_shape, self._act_shape, self._soft_q_embedding_size))

Soft Q prediction from the soft Q network:

def compute_critic_q(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    action = inputs['action']
    if len(action.shape) == 1:
        action = action.unsqueeze(1)
    state_action_input = torch.cat([inputs['obs'], action], dim=1)
    q_value = self._soft_q_net_forward(state_action_input)
    return {'q_value': q_value}

Action prediction from the policy network:

def compute_actor(self, obs: torch.Tensor, deterministic_eval=False, epsilon=1e-6) -> Dict[str, torch.Tensor]:
    mean, log_std = self._policy_net_forward(obs)
    std = log_std.exp()

    # unbounded Gaussian as the action distribution.
    dist = Independent(Normal(mean, std), 1)
    # for reparameterization trick (mean + std * N(0,1))
    if deterministic_eval:
        x = mean
    else:
        x = dist.rsample()
    y = torch.tanh(x)
    action = y

    # epsilon is used to avoid log of zero/negative number.
    y = 1 - y.pow(2) + epsilon
    log_prob = dist.log_prob(x).unsqueeze(-1)
    log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)

    return {'mean': mean, 'log_std': log_std, 'action': action, 'log_prob': log_prob}

Note

SAC applies an invertible squashing function (tanh) to the Gaussian samples and employs the change-of-variables formula to compute the likelihoods of the bounded actions. Specifically, we use an unbounded Gaussian as the action distribution through Independent(Normal(mean, std), 1), which creates a diagonal Normal distribution with the same event shape as a multivariate Normal distribution; its log_prob equals the per-dimension Normal log_prob summed over the last axis, i.e. log_prob.sum(axis=-1). Then the sampled pre-squash action \(\mathbf{u}\) (the mean itself when deterministic_eval is True) is squashed as \(\mathbf{a}=\tanh(\mathbf{u})\), and the log-likelihood of the action has the simple form \(\log \pi(\mathbf{a} \mid \mathbf{s})=\log \mu(\mathbf{u} \mid \mathbf{s})-\sum_{i=1}^{D} \log \left(1-\tanh ^{2}\left(u_{i}\right)\right)\), where \(\mu\) is the Gaussian density and \(D\) is the action dimension. Note that the std in SAC is predicted from the observation, unlike PPO (where it is a learnable parameter) and TD3 (where it is a heuristic parameter).
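The change-of-variables correction can be checked numerically without any framework. The sketch below assumes a 1-D pre-squash variable \(u \sim \mathcal{N}(\mu, \sigma^2)\), computes \(\log\pi(a) = \log\mu(u) - \log(1-\tanh^2 u)\) for \(a = \tanh(u)\), and compares the resulting density with a finite-difference derivative of the squashed CDF \(P(A \le a) = \Phi((\operatorname{atanh}(a) - \mu)/\sigma)\). For \(D\) dimensions the per-dimension terms are simply summed, which is what the `.sum(-1, keepdim=True)` in compute_actor does. Function names are hypothetical.

```python
import math


def normal_log_prob(u: float, mean: float, std: float) -> float:
    """log N(u; mean, std^2)."""
    z = (u - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def squashed_log_prob(u: float, mean: float, std: float, epsilon: float = 1e-6) -> float:
    """log density of a = tanh(u): log mu(u) - log(1 - tanh(u)^2); epsilon guards log(0) as in compute_actor."""
    return normal_log_prob(u, mean, std) - math.log(1.0 - math.tanh(u) ** 2 + epsilon)


def squashed_cdf(a: float, mean: float, std: float) -> float:
    """P(tanh(U) <= a) = P(U <= atanh(a)) for U ~ N(mean, std^2)."""
    z = (math.atanh(a) - mean) / std
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


mean, std, u = 0.3, 0.8, 0.5
a = math.tanh(u)
h = 1e-5
numeric_density = (squashed_cdf(a + h, mean, std) - squashed_cdf(a - h, mean, std)) / (2 * h)
analytic_density = math.exp(squashed_log_prob(u, mean, std, epsilon=0.0))
print(numeric_density, analytic_density)  # the two agree to several decimal places
```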

Entropy-regularized reinforcement learning is implemented as follows:

Entropy in target q value.

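# target q value. SARSA: first predict next action, then calculate next q value
with torch.no_grad():
    next_data = {'obs': next_obs}
    # sample the next action a' ~ pi(.|s') from the next observation
    next_action = self._learn_model.forward(next_data['obs'], mode='compute_actor', deterministic_eval=False)
    next_data['action'] = next_action['action']
    next_data['log_prob'] = next_action['log_prob']
    # evaluate the target soft Q network(s) on (s', a'); the mode name is assumed to select compute_critic_q
    target_q_value = self._target_model.forward(next_data, mode='compute_critic_q')['q_value']
    # the value of a policy according to the maximum entropy objective
    if self._twin_q:
        # find min one as target q value
        target_q_value = torch.min(target_q_value[0],
                                   target_q_value[1]) - self._alpha * next_data['log_prob'].squeeze(-1)
    else:
        target_q_value = target_q_value - self._alpha * next_data['log_prob'].squeeze(-1)
    # soft value of s', used as target_value in the q loss
    target_value = target_q_value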
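Stripped of the model calls, the quantity computed in this block is the soft value of the next state, \(V(s') = \min_i Q_{\text{targ},i}(s', a') - \alpha \log \pi(a' \mid s')\) with \(a' \sim \pi(\cdot \mid s')\). A framework-free sketch on per-sample lists, with a hypothetical function name:

```python
from typing import List, Optional


def soft_next_value(q1: List[float], log_prob: List[float], alpha: float,
                    q2: Optional[List[float]] = None) -> List[float]:
    """Per-sample soft value of s': min over twin target Qs (if given) minus alpha * log pi(a'|s')."""
    if q2 is None:
        q_min = q1
    else:
        q_min = [min(x, y) for x, y in zip(q1, q2)]  # clipped double-Q: take the smaller estimate
    return [q - alpha * lp for q, lp in zip(q_min, log_prob)]


# Two next states; the twin minimum curbs overestimation of the target.
print(soft_next_value(q1=[1.0, 3.0], q2=[2.0, 2.5], log_prob=[-1.0, 0.5], alpha=0.2))
```

A negative log-probability (a less likely, more exploratory action) raises the soft value, which is how the entropy bonus enters the Q target.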

Soft Q value network update.

# =================
# q network
# =================
# compute q loss
if self._twin_q:
    q_data0 = v_1step_td_data(q_value[0], target_value, reward, done, data['weight'])
    loss_dict['q_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
    q_data1 = v_1step_td_data(q_value[1], target_value, reward, done, data['weight'])
    loss_dict['q_twin_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma)
    td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2
else:
    q_data = v_1step_td_data(q_value, target_value, reward, done, data['weight'])
    loss_dict['q_loss'], td_error_per_sample = v_1step_td_error(q_data, self._gamma)

# update q network
self._optimizer_q.zero_grad()
loss_dict['q_loss'].backward()
if self._twin_q:
    loss_dict['q_twin_loss'].backward()
self._optimizer_q.step()
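The update above regresses each Q network onto the same one-step target. The framework-free sketch below assumes v_1step_td_error returns the importance-weighted mean of squared errors against \(r + \gamma(1-d)V(s')\) together with the per-sample errors; that is an assumption about the helper based on its arguments, not its documented behavior. `one_step_td_loss` is hypothetical.

```python
from typing import List, Tuple


def one_step_td_loss(q: List[float], next_v: List[float], reward: List[float], done: List[bool],
                     weight: List[float], gamma: float) -> Tuple[float, List[float]]:
    """Weighted MSE between Q(s, a) and r + gamma * (1 - d) * V(s'); also returns per-sample errors."""
    per_sample = []
    for q_i, v_i, r_i, d_i in zip(q, next_v, reward, done):
        target = r_i + gamma * (1.0 - float(d_i)) * v_i   # no bootstrapping past terminal states
        per_sample.append((q_i - target) ** 2)
    loss = sum(w * e for w, e in zip(weight, per_sample)) / len(per_sample)
    return loss, per_sample


next_v = [1.2, 2.4]
reward, done, weight, gamma = [0.0, 1.0], [False, True], [1.0, 1.0], 0.99
loss0, err0 = one_step_td_loss([1.0, 1.5], next_v, reward, done, weight, gamma)  # first Q network
loss1, err1 = one_step_td_loss([1.1, 0.9], next_v, reward, done, weight, gamma)  # twin Q network
td_error_per_sample = [(e0 + e1) / 2 for e0, e1 in zip(err0, err1)]  # averaged, as in the twin branch
print(loss0, loss1, td_error_per_sample)
```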

Entropy in policy loss.

# compute policy loss
if not self._reparameterization:
    target_log_policy = new_q_value - v_value
    policy_loss = (log_prob * (log_prob - target_log_policy.unsqueeze(-1))).mean()
else:
    policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean()
loss_dict['policy_loss'] = policy_loss  # store the loss so the backward() call below can find it

# update policy network
self._optimizer_policy.zero_grad()
loss_dict['policy_loss'].backward()
self._optimizer_policy.step()
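To see how the reparameterized loss \(\alpha \log \pi(\tilde{a}|s) - Q(s, \tilde{a})\) trades reward against entropy, the sketch below minimizes it with plain SGD for a single state, using a 1-D Gaussian policy without tanh squashing and a hypothetical critic \(Q(a) = -(a-2)^2\). Gradients are written out by hand through \(\tilde{a} = \mu + \sigma\epsilon\). The exact optimum is \(\mu = 2\), \(\sigma = \sqrt{\alpha/2}\), so a larger \(\alpha\) keeps the policy more stochastic. This is a toy illustration, not DI-engine's training loop.

```python
import math
import random


def train_gaussian_policy(alpha: float, steps: int = 4000, batch: int = 64,
                          lr: float = 0.01, seed: int = 0):
    """Minimize E_eps[alpha * log pi(a) - Q(a)], a = mu + sigma * eps, with Q(a) = -(a - 2)^2 (toy critic)."""
    rng = random.Random(seed)
    mu, log_std = 0.0, 0.0                     # learn log(sigma) so sigma stays positive
    for _ in range(steps):
        sigma = math.exp(log_std)
        g_mu = g_log_std = 0.0
        for _ in range(batch):
            eps = rng.gauss(0.0, 1.0)
            a = mu + sigma * eps                # reparameterized sample
            dq_da = -2.0 * (a - 2.0)            # toy critic gradient
            # log pi(a) = -eps^2/2 - log_std - 0.5*log(2*pi): with eps fixed, d/dlog_std = -1, d/dmu = 0
            g_mu += -dq_da                      # d(-Q)/dmu = -dQ/da * da/dmu
            g_log_std += alpha * -1.0 - dq_da * sigma * eps   # da/dlog_std = sigma * eps
        mu -= lr * g_mu / batch
        log_std -= lr * g_log_std / batch
    return mu, math.exp(log_std)


for alpha in (0.1, 0.5):
    mu, sigma = train_gaussian_policy(alpha)
    print(f"alpha={alpha}: mu={mu:.3f}, sigma={sigma:.3f}, analytic sigma={math.sqrt(alpha / 2):.3f}")
```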

Note

We implement the reparameterization trick through \((\text{mean} + \text{std} * \mathcal{N}(0,1))\), i.e. dist.rsample() in compute_actor. In particular, the gradient for sigma is also backpropagated through log_prob in the policy loss.

Auto alpha strategy

Alpha initialization: the target entropy is set to minus the product of the action shape, and log(alpha) is initialized from learn.alpha.

if self._cfg.learn.is_auto_alpha:
    self._target_entropy = -np.prod(self._cfg.model.action_shape)
    self._log_alpha = torch.log(torch.tensor([self._cfg.learn.alpha]))
    self._log_alpha = self._log_alpha.to(device='cuda' if self._cuda else 'cpu').requires_grad_()
    self._alpha_optim = torch.optim.Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha)
    self._is_auto_alpha = True
    assert self._log_alpha.shape == torch.Size([1]) and self._log_alpha.requires_grad
    self._alpha = self._log_alpha.detach().exp()
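The heuristic target entropy \(-\dim(\mathcal{A})\) sets the desired entropy to minus the number of action dimensions, and alpha is optimized in log space so that it stays positive. A stdlib-only sketch of this initialization, assuming the action shape is given as a tuple (math.prod needs Python 3.8+); `init_auto_alpha` is hypothetical:

```python
import math
from typing import Tuple


def init_auto_alpha(action_shape: Tuple[int, ...], init_alpha: float) -> Tuple[float, float]:
    """Return (target_entropy, log_alpha): target entropy = -prod(action_shape), log_alpha = log(init_alpha)."""
    target_entropy = -float(math.prod(action_shape))
    log_alpha = math.log(init_alpha)   # optimize log(alpha) so that alpha = exp(log_alpha) stays positive
    return target_entropy, log_alpha


target_entropy, log_alpha = init_auto_alpha((6,), init_alpha=0.2)
print(target_entropy, math.exp(log_alpha))
```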

Alpha update.

# compute alpha loss
if self._is_auto_alpha:
    log_prob = log_prob.detach() + self._target_entropy
    loss_dict['alpha_loss'] = -(self._log_alpha * log_prob).mean()

    self._alpha_optim.zero_grad()
    loss_dict['alpha_loss'].backward()
    self._alpha_optim.step()
    self._alpha = self._log_alpha.detach().exp()
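With log_prob detached, the alpha loss \(-\log\alpha \cdot (\log\pi + \bar{\mathcal{H}})\) has gradient \(-\text{mean}(\log\pi + \bar{\mathcal{H}})\) with respect to \(\log\alpha\), where \(\bar{\mathcal{H}}\) is the target entropy. When the policy's entropy estimate \(-\text{mean}(\log\pi)\) falls below the target, this gradient is negative, so a descent step raises alpha and encourages exploration; when the entropy exceeds the target, alpha shrinks. The sketch below performs one plain SGD step (a simplification, since the code above uses Adam); `alpha_sgd_step` is hypothetical.

```python
import math
from typing import List


def alpha_sgd_step(log_alpha: float, log_probs: List[float], target_entropy: float, lr: float) -> float:
    """One gradient-descent step on loss = -mean(log_alpha * (log_prob + target_entropy))."""
    grad = -sum(lp + target_entropy for lp in log_probs) / len(log_probs)  # d loss / d log_alpha
    return log_alpha - lr * grad


target_entropy = -1.0                 # e.g. a 1-D action space
log_alpha = math.log(0.2)
low_entropy_batch = [2.0, 1.5]        # large log-probs: entropy estimate -1.75 < target -1.0
high_entropy_batch = [-3.0, -2.0]     # small log-probs: entropy estimate 2.5 > target -1.0
print(math.exp(alpha_sgd_step(log_alpha, low_entropy_batch, target_entropy, lr=0.1)))   # alpha grows above 0.2
print(math.exp(alpha_sgd_step(log_alpha, high_entropy_batch, target_entropy, lr=0.1)))  # alpha shrinks below 0.2
```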

The benchmark results of SAC implemented in DI-engine are shown in Benchmark.

Other Public Implementations