proportional_per.py 5.7 KB
Newer Older
Z
Zheyue Tan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np


class SumTree(object):
    """Array-backed binary sum tree over a fixed-size ring buffer of items.

    ``tree`` holds ``2 * capacity - 1`` nodes in heap order: the children of
    node ``i`` are ``2 * i + 1`` and ``2 * i + 2``.  The last ``capacity``
    nodes are the leaves; leaf ``capacity - 1 + k`` stores the priority of
    ``elements[k]``.  Every internal node stores the sum of its children, so
    ``tree[0]`` is the total priority.
    """

    def __init__(self, capacity):
        """Create an empty tree.

        Args:
            capacity: number of item slots (leaves); must be at least 1.

        Raises:
            ValueError: if ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError("capacity must be a positive integer")
        self.capacity = capacity
        self.elements = [None for _ in range(capacity)]
        self.tree = [0 for _ in range(2 * capacity - 1)]
        # Next slot of `elements` that `add` will (over)write.
        self._ptr = 0
        # Running minimum over every priority ever written through `update`.
        # Starts at +inf so the first written priority always becomes the
        # minimum, whatever its magnitude. It is never raised again, even
        # when the leaf that held the minimum is overwritten.
        self._min = float("inf")
        # Cached result of `full()`; see that method.
        self._full = False

    def full(self):
        """Return True when no slot of ``elements`` is ``None``.

        Slots are compared with ``is not None`` so that falsy items (``0``,
        empty containers, ...) count as stored.  ``add`` only overwrites
        slots, so once the buffer has been seen full the answer is cached
        and later calls avoid rescanning the whole buffer.
        """
        if not self._full:
            self._full = all(e is not None for e in self.elements)
        return self._full

    def add(self, item, priority):
        """Store ``item`` in the next ring-buffer slot with ``priority``.

        When the buffer has wrapped around, the oldest slot is overwritten.
        """
        self.elements[self._ptr] = item
        tree_idx = self._ptr + self.capacity - 1
        # `update` also maintains the running minimum priority.
        self.update(tree_idx, priority)
        self._ptr = (self._ptr + 1) % self.capacity

    def update(self, tree_idx, priority):
        """Set the priority of leaf ``tree_idx`` and refresh ancestor sums.

        Args:
            tree_idx: index into ``tree`` (not into ``elements``) of the leaf.
            priority: new priority value for that leaf.
        """
        diff = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        # Walk up to the root, adding the change to every ancestor.
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) >> 1
            self.tree[tree_idx] += diff
        self._min = min(self._min, priority)

    def retrieve(self, value):
        """Find the leaf whose cumulative-priority interval contains ``value``.

        Args:
            value: a number, normally drawn from ``[0, total_p]``.

        Returns:
            A tuple ``(item, leaf_idx, priority)`` where ``leaf_idx`` indexes
            ``tree`` and ``item`` is the matching entry of ``elements``.
        """
        node = 0
        while True:
            left = 2 * node + 1
            if left >= len(self.tree):
                # `node` has no children, so it is a leaf.
                break
            if value <= self.tree[left]:
                node = left
            else:
                # Skip the whole left subtree and descend to the right.
                value -= self.tree[left]
                node = left + 1
        elem_idx = node - self.capacity + 1
        return self.elements[elem_idx], node, self.tree[node]

    def from_list(self, lst):
        """Replace all items with ``lst`` and give every leaf priority 1.0.

        Args:
            lst: sequence of exactly ``capacity`` items.
        """
        assert len(lst) == self.capacity
        self.elements = list(lst)
        # The new contents may contain `None`; force `full()` to rescan.
        self._full = False
        for i in range(self.capacity - 1, 2 * self.capacity - 1):
            self.update(i, 1.0)

    @property
    def total_p(self):
        """Sum of all leaf priorities (the value stored at the root)."""
        return self.tree[0]


class ProportionalPER(object):
    """Proportional Prioritized Experience Replay.

    Transitions ``(s, a, r, s', terminal)`` are kept in a ``SumTree`` and
    stored with priority ``(delta + eps) ** alpha``.
    """

    def __init__(self,
                 alpha,
                 seg_num,
                 size=1e6,
                 eps=0.01,
                 init_mem=None,
                 framestack=4):
        """Create the replay memory.

        Args:
            alpha: exponent applied to ``delta + eps`` to obtain a priority.
            seg_num: number of transitions returned by ``sample``.
            size: capacity of the memory; converted with ``int``.
            eps: constant added to every ``delta`` before exponentiation.
            init_mem: optional sequence of exactly ``size`` transitions used
                to pre-fill the memory; ``None`` or an empty sequence leaves
                the memory empty.
            framestack: number of observations stacked by ``sample``.
        """
        self.alpha = alpha
        self.seg_num = seg_num
        self.size = int(size)
        self.elements = SumTree(self.size)
        # Explicit None/length test: plain truthiness raises for NumPy arrays.
        if init_mem is not None and len(init_mem) > 0:
            self.elements.from_list(init_mem)
        self.framestack = framestack
        # Largest raw `delta` seen by `update` (before `eps` and `alpha` are
        # applied); `store` uses it for transitions that come without one.
        self._max_priority = 1.0
        self.eps = eps

    def _get_stacked_item(self, idx):
        """Return the transition at ``idx`` with a stacked observation.

        For atari environment, we use a 4-frame-stack as input.  The last
        slot of the stack is the observation stored at ``idx``; earlier slots
        are filled from the preceding buffer entries, walking backwards and
        stopping at the first entry whose terminal flag is set.  Slots that
        are not filled stay zero.

        Args:
            idx: index into the item buffer (not a tree index).

        Returns:
            ``(stacked_obs, act, reward, next_obs, done)`` where
            ``stacked_obs`` has shape ``(framestack,) + obs.shape``.
        """
        obs, act, reward, next_obs, done = self.elements.elements[idx]
        stacked_obs = np.zeros((self.framestack, ) + obs.shape)
        stacked_obs[-1] = obs
        for i in range(self.framestack - 2, -1, -1):
            # Stack slot `i` maps to buffer entry `idx - (framestack - 1 - i)`,
            # wrapped around the ring buffer.
            elem_idx = (self.size + idx + i - self.framestack + 1) % self.size
            obs, _, _, _, d = self.elements.elements[elem_idx]
            if d:
                break
            stacked_obs[i] = obs
        return (stacked_obs, act, reward, next_obs, done)

    def store(self, item, delta=None):
        """Add a transition to the memory.

        Args:
            item: 5-tuple ``(s, a, r, s', terminal)``.
            delta: non-negative error magnitude used as the raw priority.
                When ``None``, the largest ``delta`` seen so far by ``update``
                is used.  An explicit ``0`` is honoured and yields the
                priority ``eps ** alpha``.
        """
        assert len(item) == 5  # (s, a, r, s', terminal)
        if delta is None:
            delta = self._max_priority
        assert delta >= 0
        ps = np.power(delta + self.eps, self.alpha)
        self.elements.add(item, ps)

    def update(self, indices, priorities):
        """Re-prioritize previously sampled transitions.

        Args:
            indices: tree indices as returned by ``sample``.
            priorities: raw error magnitudes, one per index; each is stored
                as ``(priority + eps) ** alpha``.
        """
        deltas = np.array(priorities)
        priorities_alpha = np.power(deltas + self.eps, self.alpha)
        for idx, delta, priority in zip(indices, deltas, priorities_alpha):
            self.elements.update(idx, priority)
            # Track the raw value: `store` applies `eps` and `alpha` itself,
            # so remembering the transformed value would apply them twice.
            self._max_priority = max(delta, self._max_priority)

    def sample_one(self):
        """Sample one transition proportionally to its priority.

        Returns:
            ``(item, tree_idx)``; ``item`` is the stored transition as-is,
            without frame stacking.
        """
        assert self.elements.full(), "The replay memory is not full!"
        sample_val = np.random.uniform(0, self.elements.total_p)
        item, tree_idx, _ = self.elements.retrieve(sample_val)
        return item, tree_idx

    def sample(self, beta=1):
        """ sample a batch of `seg_num` transitions

        The total priority mass is split into `seg_num` equal segments and
        one value is drawn uniformly from each of them.

        Args:
            beta: float, degree of using importance sampling weights,
                0 - no corrections, 1 - full correction

        Returns:
            items: sampled transitions as an object array of shape
                (seg_num, 5); column 0 holds the stacked observations
            indices: idxs of sampled items, used to update priorities later
            sample_weights: importance sampling weight
        """
        assert self.elements.full(), "The replay memory is not full!"
        seg_size = self.elements.total_p / self.seg_num
        seg_bound = [(seg_size * i, seg_size * (i + 1))
                     for i in range(self.seg_num)]
        items, indices, priorities = [], [], []
        for low, high in seg_bound:
            sample_val = np.random.uniform(low, high)
            _, tree_idx, priority = self.elements.retrieve(sample_val)
            elem_idx = tree_idx - self.elements.capacity + 1
            item = self._get_stacked_item(elem_idx)
            items.append(item)
            indices.append(tree_idx)
            priorities.append(priority)

        # The fields of a transition have different shapes, so the batch is
        # built as an explicit object array; recent NumPy releases refuse to
        # infer one from ragged nested sequences.
        batch = np.empty((len(items), 5), dtype=object)
        for row, item in enumerate(items):
            for col, field in enumerate(item):
                batch[row, col] = field

        batch_probs = self.size * np.array(priorities) / self.elements.total_p
        # Normalize by the weight of the smallest priority recorded by the
        # tree.
        min_prob = self.size * self.elements._min / self.elements.total_p
        sample_weights = np.power(batch_probs / min_prob, -beta)
        return batch, np.array(indices), sample_weights