Prioritized Experience Replay Buffer

This implements the paper Prioritized Experience Replay, using a binary segment tree.


import random

import numpy as np

Buffer for Prioritized Experience Replay

Prioritized experience replay samples important transitions more frequently. The transitions are prioritized by their temporal difference error (TD error), $\delta$.

We sample transition $i$ with probability

$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$

where $p_i$ is the priority of transition $i$ and $\alpha$ is a hyper-parameter that determines how much prioritization is used, with $\alpha = 0$ corresponding to the uniform case.

We use proportional prioritization $p_i = |\delta_i| + \epsilon$, where $\delta_i$ is the TD error of transition $i$ and $\epsilon$ is a small positive constant that keeps every priority above zero.
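As a small illustration, the two formulas above can be evaluated directly for a few TD errors. The function name sampling_probabilities and the values of td_errors, alpha and eps are hypothetical, chosen only for this sketch.

```python
def sampling_probabilities(td_errors, alpha, eps=1e-6):
    """Proportional prioritization: p_i = |delta_i| + eps, P(i) = p_i^alpha / sum_k p_k^alpha."""
    priorities = [abs(d) + eps for d in td_errors]
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


td_errors = [0.5, -2.0, 0.0, 1.0]  # hypothetical TD errors
print(sampling_probabilities(td_errors, alpha=0.6))
print(sampling_probabilities(td_errors, alpha=0.0))  # alpha = 0 gives uniform sampling
```

The largest absolute TD error gets the highest probability, and the zero error still gets a small non-zero probability thanks to eps.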

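We correct the bias introduced by prioritized replay using importance-sampling (IS) weights in the loss function,

$$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$

where $N$ is the number of transitions in the buffer. This fully compensates for the non-uniform sampling when $\beta = 1$. For stability, we normalize the weights by $\frac{1}{\max_i w_i}$, so they only ever scale the loss down. Unbiased updates matter most near convergence at the end of training, so we increase $\beta$ towards $1$ as training progresses.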
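A minimal sketch of the weight computation, assuming the probabilities of all $N$ transitions in the buffer are known. The linear schedule in linear_beta, with its starting value beta_start = 0.4, is an assumption for illustration; the text above only says that $\beta$ is increased towards the end of training.

```python
def importance_weights(probs, beta):
    """w_i = (1 / (N * P(i)))^beta, normalized so that the largest weight is 1."""
    n = len(probs)
    weights = [(n * p) ** (-beta) for p in probs]
    max_w = max(weights)
    return [w / max_w for w in weights]


def linear_beta(step, total_steps, beta_start=0.4):
    """Hypothetical schedule: anneal beta linearly from beta_start to 1, then hold it at 1."""
    frac = min(step / total_steps, 1.0)
    return beta_start + frac * (1.0 - beta_start)


probs = [0.1, 0.2, 0.3, 0.4]
for step in (0, 50_000, 100_000):
    beta = linear_beta(step, total_steps=100_000)
    print(step, round(beta, 2), [round(w, 3) for w in importance_weights(probs, beta)])
```

The least likely transition always gets weight 1, and every other weight is below 1.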

Binary Segment Tree

We use a binary segment tree to efficiently calculate $\sum_{k=1}^{i} p_k^\alpha$, the (unnormalized) cumulative probability, which is needed for sampling. We also use a binary segment tree to find $\min_i p_i^\alpha$, which is needed for $\frac{1}{\max_i w_i}$; a min-heap could also be used for this. With segment trees, both updating a priority and sampling an index take $\mathcal{O}(\log n)$ time, which is much more efficient than the naive $\mathcal{O}(n)$ approach.
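For comparison, here is a sketch of the naive approach using only the standard library. It rebuilds the cumulative sums with itertools.accumulate on every call, so each sample costs $\mathcal{O}(n)$ even though the search itself uses bisect. The helper names naive_sample and naive_min are hypothetical.

```python
import bisect
import itertools
import random


def naive_sample(priorities_alpha, rng):
    """O(n) baseline: build the cumulative sums, then binary-search a uniform draw."""
    cumulative = list(itertools.accumulate(priorities_alpha))  # O(n) work per call
    u = rng.random() * cumulative[-1]                          # uniform in [0, total)
    # bisect_right returns the first index whose cumulative sum exceeds u
    return bisect.bisect_right(cumulative, u)


def naive_min(priorities_alpha):
    """O(n) baseline for min_i p_i^alpha."""
    return min(priorities_alpha)


rng = random.Random(0)
pa = [1.0, 3.0, 0.5, 2.5]
print([naive_sample(pa, rng) for _ in range(10)], naive_min(pa))
```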

This is how a binary segment tree works for the sum; it works the same way for the minimum. Let $x_i$ be the list of $N$ values we want to represent. Let $b_{i,j}$ be the $j^{\mathop{th}}$ node of the $i^{\mathop{th}}$ row in the binary tree, with rows counted from $1$ at the root and $j$ counted from $0$. The two children of node $b_{i,j}$ are $b_{i+1,2j}$ and $b_{i+1,2j + 1}$.

The leaf nodes on row $D = \left\lceil {1 + \log_2 N} \right\rceil$ hold the values of $x$. Every internal node keeps the sum of its two child nodes. That is, the root node keeps the sum of the entire array of values, and the left and right children of the root keep the sums of the first and second halves of the array, respectively. And so on…
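To make the row structure concrete, the sketch below builds the tree row by row from a list $x$ whose length is assumed to be a power of $2$. The helper build_rows is hypothetical; rows[0] is the root row (row $1$ in the text) and node $j$ is counted from $0$.

```python
def build_rows(x):
    """Build the tree row by row: rows[0] is the root row, rows[-1] holds x.

    Assumes len(x) is a power of 2, so every row is full.
    Node j of one row has children 2j and 2j + 1 in the row below (0-based j).
    """
    rows = [list(x)]
    while len(rows[0]) > 1:
        below = rows[0]
        parents = [below[2 * j] + below[2 * j + 1] for j in range(len(below) // 2)]
        rows.insert(0, parents)
    return rows


for row in build_rows([3, 1, 4, 1, 5, 9, 2, 6]):
    print(row)
# [31]
# [9, 22]
# [4, 5, 14, 8]
# [3, 1, 4, 1, 5, 9, 2, 6]
```

With $N = 8$ there are $D = \lceil 1 + \log_2 8 \rceil = 4$ rows.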

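Row $i$ has $N_i = 2^{i-1}$ nodes (assuming $N$ is a power of $2$, so every row is full). This is one more than the total number of nodes in all rows above $i$, which is $\sum_{r=1}^{i-1} 2^{r-1} = 2^{i-1} - 1$. So we can use a single array $a$ to store the tree, where

$$b_{i,j} \rightarrow a_{N_i + j}$$

Then the child nodes of $a_i$ are $a_{2i}$ and $a_{2i + 1}$. That is,

$$a_i = a_{2i} + a_{2i + 1}$$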

This way of maintaining binary trees is very easy to program, since moving between a node, its parent and its children is only index arithmetic. Note that we index starting from 1, so $a_0$ is unused.
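The same tree stored in a flat array, as a sketch assuming $N$ is a power of $2$: build_sum_tree fills the leaves at $a_N, \dots, a_{2N-1}$ and then each internal node bottom-up, and the hypothetical helper row slices out row $i$ using $N_i = 2^{i-1}$.

```python
def build_sum_tree(values):
    """Flat 1-indexed sum tree of size 2n; assumes n = len(values) is a power of 2."""
    n = len(values)
    tree = [0] * (2 * n)  # tree[0] is unused
    tree[n:] = values     # leaves occupy a_n .. a_{2n-1}
    for i in range(n - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[2 * i + 1]  # a_i = a_{2i} + a_{2i+1}
    return tree


def row(tree, i):
    """Row i (root row is i = 1) occupies a_{N_i} .. a_{2 N_i - 1}, with N_i = 2^(i-1)."""
    start = 2 ** (i - 1)
    return tree[start:2 * start]


tree = build_sum_tree([3, 1, 4, 1, 5, 9, 2, 6])
for i in range(1, 5):
    print(i, row(tree, i))
```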

We use the same structure to compute the minimum, with $a_i = \min(a_{2i}, a_{2i + 1})$.
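A sketch of the minimum tree, assuming empty buffer slots are padded with $+\infty$, the same way the priority_min array is initialized below. The padding never wins a min, so the root holds the minimum over the stored values only. build_min_tree is a hypothetical helper.

```python
import math


def build_min_tree(values, capacity):
    """Flat 1-indexed min tree; slots beyond len(values) hold +inf.

    Assumes capacity is a power of 2 and len(values) <= capacity.
    """
    tree = [math.inf] * (2 * capacity)
    tree[capacity:capacity + len(values)] = values
    for i in range(capacity - 1, 0, -1):
        tree[i] = min(tree[2 * i], tree[2 * i + 1])  # a_i = min(a_{2i}, a_{2i+1})
    return tree


tree = build_min_tree([0.7, 0.2, 0.9], capacity=4)
print(tree[1])  # 0.2: the +inf padding of the empty slot never wins
```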

class ReplayBuffer:

Initialize

    def __init__(self, capacity, alpha):

We use a power of $2$ for the capacity because it simplifies the code and debugging.

        self.capacity = capacity

$\alpha$

        self.alpha = alpha

Maintain binary segment trees to take the sum and find the minimum over a range

        self.priority_sum = [0 for _ in range(2 * self.capacity)]
        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]

Current max priority, $p$, to be assigned to new transitions

        self.max_priority = 1.

Arrays for buffer

        self.data = {
            'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'action': np.zeros(shape=capacity, dtype=np.int32),
            'reward': np.zeros(shape=capacity, dtype=np.float32),
            'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            # plain bool works across NumPy versions; np.bool is missing in NumPy 1.24 to 1.26
            'done': np.zeros(shape=capacity, dtype=bool)
        }

We use cyclic buffers to store the data; next_idx keeps the index of the next slot to write, which is the oldest entry once the buffer is full.

        self.next_idx = 0

Size of the buffer

        self.size = 0

Add sample to queue

    def add(self, obs, action, reward, next_obs, done):

Get next available slot

        idx = self.next_idx

Store in the queue

        self.data['obs'][idx] = obs
        self.data['action'][idx] = action
        self.data['reward'][idx] = reward
        self.data['next_obs'][idx] = next_obs
        self.data['done'][idx] = done

Increment next available slot

        self.next_idx = (idx + 1) % self.capacity

Calculate the size

        self.size = min(self.capacity, self.size + 1)

$p_i^\alpha$, new samples get max_priority

        priority_alpha = self.max_priority ** self.alpha

Update the two segment trees for sum and minimum

        self._set_priority_min(idx, priority_alpha)
        self._set_priority_sum(idx, priority_alpha)
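The index arithmetic of add() can be traced without storing any data. This hypothetical simulate_writes helper shows the write position wrapping around and the size saturating at the capacity.

```python
def simulate_writes(capacity, num_writes):
    """Replay the next_idx / size bookkeeping of add() without storing any data."""
    next_idx, size = 0, 0
    slots = []
    for _ in range(num_writes):
        idx = next_idx
        slots.append(idx)                # slot this write goes to
        next_idx = (idx + 1) % capacity  # wrap around: overwrite the oldest entry
        size = min(capacity, size + 1)   # size saturates at capacity
    return slots, next_idx, size


print(simulate_writes(capacity=4, num_writes=6))  # ([0, 1, 2, 3, 0, 1], 2, 4)
```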

Set priority in binary segment tree for minimum

    def _set_priority_min(self, idx, priority_alpha):

Leaf of the binary tree

        idx += self.capacity
        self.priority_min[idx] = priority_alpha

Update the tree by walking up through the ancestors of the leaf until reaching the root.

        while idx >= 2:

Get the index of the parent node

            idx //= 2

The value of the parent node is the minimum of its two children

            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])
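Each update touches exactly one node per row, which is why it costs $\mathcal{O}(\log n)$. The hypothetical helper ancestors lists the tree indices that this update loop visits for a given buffer index.

```python
def ancestors(leaf, capacity):
    """Tree indices visited by the update loop for a given leaf (buffer) index."""
    idx = leaf + capacity
    path = [idx]
    while idx >= 2:
        idx //= 2
        path.append(idx)
    return path


print(ancestors(5, capacity=8))  # [13, 6, 3, 1]
```

For a capacity of $2^k$ the path always has $k + 1$ entries and ends at the root, index 1.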

Set priority in binary segment tree for sum

    def _set_priority_sum(self, idx, priority):

Leaf of the binary tree

        idx += self.capacity

Set the priority at the leaf

        self.priority_sum[idx] = priority

Update the tree by walking up through the ancestors of the leaf until reaching the root.

        while idx >= 2:

Get the index of the parent node

            idx //= 2

The value of the parent node is the sum of its two children

            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]
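A standalone sketch of the same point update on a plain list (set_leaf is a hypothetical name), checking that overwriting one leaf keeps every internal node equal to the sum of its children.

```python
def set_leaf(tree, capacity, idx, value):
    """Point update mirroring _set_priority_sum: write the leaf, then fix its ancestors."""
    idx += capacity
    tree[idx] = value
    while idx >= 2:
        idx //= 2
        tree[idx] = tree[2 * idx] + tree[2 * idx + 1]


capacity = 4
tree = [0.0] * (2 * capacity)
for i, v in enumerate([1.0, 2.0, 3.0, 4.0]):
    set_leaf(tree, capacity, i, v)
set_leaf(tree, capacity, 2, 10.0)  # overwrite one priority
print(tree[1])  # 17.0 = 1 + 2 + 10 + 4
```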

$\sum_k p_k^\alpha$

    def _sum(self):

The root node keeps the sum of all values

        return self.priority_sum[1]

$\min_k p_k^\alpha$

    def _min(self):

The root node keeps the minimum of all values

        return self.priority_min[1]

Find the smallest buffer index $i$ (counted from $0$) such that $\sum_{k=0}^{i} p_k^\alpha > P$, where $P$ is the given prefix_sum

    def find_prefix_sum_idx(self, prefix_sum):

Start from the root

        idx = 1
        while idx < self.capacity:

If the sum of the left branch is greater than the required sum

            if self.priority_sum[idx * 2] > prefix_sum:

Go to left branch of the tree

                idx = 2 * idx
            else:

Otherwise, go to the right branch and subtract the sum of the left branch from the required sum

                prefix_sum -= self.priority_sum[idx * 2]
                idx = 2 * idx + 1

We are at a leaf node. Subtract the capacity from the tree index to get the index of the actual value in the buffer

        return idx - self.capacity
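A standalone sketch that mirrors this descent on a plain list and checks it against an $\mathcal{O}(n)$ linear scan. The priority values are made up; slot 2 has priority 0, like an empty slot, and is never returned as long as the prefix sum is below the total.

```python
import itertools


def find_prefix_sum_idx(tree, capacity, prefix_sum):
    """Descend from the root, mirroring ReplayBuffer.find_prefix_sum_idx."""
    idx = 1
    while idx < capacity:
        if tree[2 * idx] > prefix_sum:
            idx = 2 * idx                # target lies in the left subtree
        else:
            prefix_sum -= tree[2 * idx]  # skip the whole left subtree
            idx = 2 * idx + 1
    return idx - capacity


def linear_scan(values, prefix_sum):
    """O(n) reference: first index whose inclusive cumulative sum exceeds prefix_sum."""
    for i, c in enumerate(itertools.accumulate(values)):
        if c > prefix_sum:
            return i
    return len(values) - 1


values = [1.0, 3.0, 0.0, 2.0]  # assumed p^alpha values; capacity 4 is a power of 2
capacity = len(values)
tree = [0.0] * capacity + values
for i in range(capacity - 1, 0, -1):
    tree[i] = tree[2 * i] + tree[2 * i + 1]

for p in (0.0, 0.5, 1.0, 3.9, 4.0, 5.5):
    print(p, find_prefix_sum_idx(tree, capacity, p), linear_scan(values, p))
```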

Sample from buffer

    def sample(self, batch_size, beta):

Initialize samples

        samples = {
            'weights': np.zeros(shape=batch_size, dtype=np.float32),
            'indexes': np.zeros(shape=batch_size, dtype=np.int32)
        }

Get sample indexes

        for i in range(batch_size):
            p = random.random() * self._sum()
            idx = self.find_prefix_sum_idx(p)
            samples['indexes'][i] = idx
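A quick empirical check, as a standalone sketch: drawing many indexes the way this loop does (a uniform prefix sum followed by the tree descent) should reproduce $P(i)$. The seed, the number of draws and the priority values are arbitrary choices for illustration, and the helper names are hypothetical.

```python
import random
from collections import Counter


def build_sum_tree(values):
    n = len(values)  # assumed to be a power of 2
    tree = [0.0] * n + list(values)
    for i in range(n - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree


def sample_index(tree, capacity, rng):
    """Draw one index with probability p_i^alpha / sum_k p_k^alpha."""
    prefix_sum = rng.random() * tree[1]  # uniform in [0, total)
    idx = 1
    while idx < capacity:
        if tree[2 * idx] > prefix_sum:
            idx = 2 * idx
        else:
            prefix_sum -= tree[2 * idx]
            idx = 2 * idx + 1
    return idx - capacity


values = [1.0, 3.0, 0.0, 4.0]
tree = build_sum_tree(values)
rng = random.Random(0)
n_draws = 40_000
counts = Counter(sample_index(tree, len(values), rng) for _ in range(n_draws))
for i, v in enumerate(values):
    print(i, round(counts[i] / n_draws, 3), v / tree[1])
```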

$\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$

        prob_min = self._min() / self._sum()

$\max_i w_i = \bigg(\frac{1}{N} \frac{1}{\min_i P(i)}\bigg)^\beta$

        max_weight = (prob_min * self.size) ** (-beta)

        for i in range(batch_size):
            idx = samples['indexes'][i]

$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$

            prob = self.priority_sum[idx + self.capacity] / self._sum()

$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$

            weight = (prob * self.size) ** (-beta)

Normalize by $\frac{1}{\max_i w_i}$, which also cancels off the $\frac{1}{N}$ term

            samples['weights'][i] = weight / max_weight
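A worked check of the normalization, with made-up numbers: dividing $w_i$ by $\max_i w_i$ gives $\big(\min_j P(j) / P(i)\big)^\beta$, so the buffer size $N$ drops out. normalized_weight is a hypothetical helper that mirrors the two expressions used in sample().

```python
def normalized_weight(prob, prob_min, size, beta):
    """w_i / max_j w_j, computed the same way as in sample()."""
    weight = (prob * size) ** (-beta)
    max_weight = (prob_min * size) ** (-beta)
    return weight / max_weight


size, beta = 1000, 0.5
prob_min, prob = 1e-4, 4e-4
print(normalized_weight(prob, prob_min, size, beta))  # ≈ 0.5
print((prob_min / prob) ** beta)                      # ≈ 0.5, with no N involved
```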

Get the data for the sampled indexes

        for k, v in self.data.items():
            samples[k] = v[samples['indexes']]

        return samples

Update priorities

    def update_priorities(self, indexes, priorities):
        for idx, priority in zip(indexes, priorities):

Set current max priority

            self.max_priority = max(self.max_priority, priority)

Calculate $p_i^\alpha$

            priority_alpha = priority ** self.alpha

Update the trees

            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)
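A sketch of how the caller side might feed TD errors back in, assuming proportional priorities $|\delta_i| + \epsilon$ as described at the top. The TD errors, eps, and the helper apply_updates (which mirrors this loop on a plain list of leaves instead of the two trees) are hypothetical. Note that max_priority tracks the raw priority; the exponent $\alpha$ is applied only when writing to the tree.

```python
def apply_updates(max_priority, leaves, indexes, priorities, alpha):
    """Mirror update_priorities(): store p_i^alpha and track the largest raw priority seen."""
    for idx, priority in zip(indexes, priorities):
        max_priority = max(max_priority, priority)
        leaves[idx] = priority ** alpha
    return max_priority


leaves = [1.0] * 4       # new transitions start at max_priority ** alpha = 1.0
td_errors = [0.3, -2.5]  # hypothetical TD errors for the sampled indexes 1 and 3
eps = 1e-6               # hypothetical epsilon
priorities = [abs(d) + eps for d in td_errors]
max_p = apply_updates(1.0, leaves, [1, 3], priorities, alpha=0.6)
print(max_p)  # largest raw priority so far; the next add() would use it
print(leaves)
```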

Whether the buffer is full

    def is_full(self):
        return self.capacity == self.size