From 406ffdbfd4c47a5cc62df8ac9a708c7cf1b4e726 Mon Sep 17 00:00:00 2001 From: Xiaochen Lian Date: Sun, 27 May 2018 21:09:15 -0700 Subject: [PATCH] A simple replay buffer (#5) * simple replay buffer and its test * add error handling * add test for deep copy --- .gitignore | 3 + parl/common/__init__.py | 13 +++ parl/common/error_handling.py | 71 +++++++++++++ parl/common/replay_buffer.py | 136 ++++++++++++++++++++++++ parl/common/tests/test_replay_buffer.py | 93 ++++++++++++++++ 5 files changed, 316 insertions(+) create mode 100644 parl/common/__init__.py create mode 100644 parl/common/error_handling.py create mode 100644 parl/common/replay_buffer.py create mode 100644 parl/common/tests/test_replay_buffer.py diff --git a/.gitignore b/.gitignore index 4c3f8d6..9560657 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ # tilde *~ +# VIM swap file +.*.swp + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/parl/common/__init__.py b/parl/common/__init__.py new file mode 100644 index 0000000..eca2dce --- /dev/null +++ b/parl/common/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/parl/common/error_handling.py b/parl/common/error_handling.py new file mode 100644 index 0000000..86d3ad6 --- /dev/null +++ b/parl/common/error_handling.py @@ -0,0 +1,71 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class LastExpError(Exception): + """ + Raised when the last element or an element with non-zero game status is + sampled. + + Attributes: + message(string): error message + """ + + def __init__(self, idx, status): + self.message = 'The element at {}'.format(idx) + if status: + self.message += ' has game status: {}'.format(status) + else: + self.message += ' is the last experience of a game.' + + +def check_last_exp_error(is_last_exp, idx, game_status): + if is_last_exp: + raise LastExpError(idx, game_status) + + +def check_type_error(type1, type2): + if type1.__name__ != type2.__name__: + raise TypeError('{} expected, but {} given.' + .format(type1.__name__, type2.__name__)) + + +def check_eq(v1, v2): + if v1 != v2: + raise ValueError('{} == {} does not hold'.format(v1, v2)) + + +def check_neq(v1, v2): + if v1 == v2: + raise ValueError('{} != {} does not hold'.format(v1, v2)) + + +def check_gt(v1, v2): + if v1 <= v2: + raise ValueError('{} > {} does not hold'.format(v1, v2)) + + +def check_geq(v1, v2): + if v1 < v2: + raise ValueError('{} >= {} does not hold'.format(v1, v2)) + + +def check_lt(v1, v2): + if v1 >= v2: + raise ValueError('{} < {} does not hold'.format(v1, v2)) + + +def check_leq(v1, v2): + if v1 > v2: + raise ValueError('{} <= {} does not hold'.format(v1, v2)) diff --git a/parl/common/replay_buffer.py b/parl/common/replay_buffer.py new file mode 100644 index 0000000..3636e26 --- /dev/null +++ b/parl/common/replay_buffer.py @@ -0,0 +1,136 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import random +from parl.common.error_handling import * + + +class Experience(object): + def __init__(self, sensor_inputs, states, actions, game_status): + check_type_error(list, type(sensor_inputs)) + self.sensor_inputs = sensor_inputs # (observation, reward) + self.states = states # other states + self.actions = actions # actions taken + self.game_status = game_status # game status, e.g., max_steps or + # episode end reached + self.next_exp = None # copy of the next Experience + + def set_next_exp(self, next_exp): + self.next_exp = copy.deepcopy(next_exp) + + #TODO: write copy function + + +class Sample(object): + """ + A Sample represents one or a sequence of Experiences + """ + + def __init__(self, i, n): + self.i = i # starting index of the first experience in the sample + self.n = n # length of the sequence + + def __repr__(self): + return str(self.__class__) + ": " + str(self.__dict__) + + +class ReplayBuffer(object): + def __init__(self, capacity, exp_type=Experience): + """ + Create Replay buffer. + + Args: + exp_type(object): Experience class used in the buffer. + capacity(int): Max number of experience to store in the buffer. When + the buffer overflows the old memories are dropped. + """ + check_gt(capacity, 1) + self.buffer = [] # a circular queue to store experiences + self.capacity = capacity # capacity of the buffer + self.last = -1 # the index of the last element in the buffer + self.exp_type = exp_type # Experience class used in the buffer + + def __len__(self): + return len(self.buffer) + + def buffer_end(self, i): + return i == self.last + + def next_idx(self, i): + if self.buffer_end(i): + return -1 + else: + return (i + 1) % self.capacity + + def add(self, exp): + """ + Store one experience into the buffer. + + Args: + exp(self.exp_type): the experience to store in the buffer. + """ + check_type_error(self.exp_type, type(exp)) + # the next_exp field should be None at this point + check_eq(exp.next_exp, None) + + if len(self.buffer) < self.capacity: + self.buffer.append(None) + self.last = (self.last + 1) % self.capacity + self.buffer[self.last] = copy.deepcopy(exp) + + def sample(self, num_samples): + """ + Generate a batch of Samples. Each Sample represents a sequence of + Experiences (length>=1). And a sequence must not cross the boundary + between two games. + + Args: + num_samples(int): Number of samples to generate. + + Returns: A generator of Samples + """ + if len(self.buffer) <= 1: + yield [] + + for _ in xrange(num_samples): + while True: + idx = random.randint(0, len(self.buffer) - 1) + if not self.buffer_end(idx) and not self.buffer[ + idx].game_status: + break + yield Sample(idx, 1) + + def get_experiences(self, sample): + """ + Get Experiences from a Sample + + Args: + sample(Sample): a Sample representing a sequence of Experiences + + Return(list): a list of Experiences + """ + exps = [] + p = sample.i + for _ in xrange(sample.n): + check_last_exp_error( + self.buffer_end(p) or self.buffer[p].game_status, p, + self.buffer[p].game_status) + # make a copy of the buffer element as e may be modified somewhere + e = copy.deepcopy(self.buffer[p]) + p = self.next_idx(p) + e.set_next_exp(self.buffer[p]) + exps.append(e) + + return exps diff --git a/parl/common/tests/test_replay_buffer.py b/parl/common/tests/test_replay_buffer.py new file mode 100644 index 0000000..1750c36 --- /dev/null +++ b/parl/common/tests/test_replay_buffer.py @@ -0,0 +1,93 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import unittest +from parl.common.error_handling import LastExpError +from parl.common.replay_buffer import Experience, Sample, ReplayBuffer + + +class ExperienceForTest(Experience): + def __init__(self, obs, reward, actions, new_field, status): + super(ExperienceForTest, self).__init__([obs, reward], [], actions, + status) + self.new_field = new_field + + +class TestReplayBuffer(unittest.TestCase): + def test_single_instance_replay_buffer(self): + capacity = 30 + episode_len = 4 + buf = ReplayBuffer(capacity, ExperienceForTest) + total = 0 + expect_total = 0 + for i in xrange(10 * capacity): + e = ExperienceForTest( + obs=np.zeros(10), + reward=i * 0.5, + actions=i, + new_field=np.ones(20), + status=(i + 1) % episode_len == 0) + buf.add(e) + # check the circular queue in the buffer + self.assertTrue(len(buf) == min(i + 1, capacity)) + if (len(buf) < 2): # need at least two elements + continue + # should raise error when trying to pick up the last element + with self.assertRaises(LastExpError): + t = Sample(i % capacity, 1) + buf.get_experiences(t) + expect_total += len(buf) + # neither last element nor episode end should be picked up + for s in buf.sample(len(buf)): + try: + exps = buf.get_experiences(s) + total += 1 + except LastExpError as err: + self.fail('test_single_instance_replay_buffer raised ' + 'LastExpError: ' + err.message) + # check the total number of elements added into the buffer + self.assertTrue(total == expect_total) + # detect incompatible Experience type + with self.assertRaises(TypeError): + e = Experience([np.zeros(10), i * 0.5], [], i, 0) + buf.add(e) + + def test_deep_copy(self): + capacity = 5 + buf = ReplayBuffer(capacity, Experience) + e0 = Experience( + sensor_inputs=[np.zeros(10), 0], + states=[], + actions=0, + game_status=0) + e1 = Experience([np.ones(10) * 2, 1], [], 0, 1) + buf.add(e0) + e0.sensor_inputs[0] += 1 + buf.add(e0) + buf.add(e1) + s = Sample(0, 2) + exps = buf.get_experiences(s) + self.assertEqual(np.sum(exps[0].sensor_inputs[0] == 0), 10) + self.assertEqual(np.sum(exps[1].sensor_inputs[0] == 1), 10) + self.assertEqual(np.sum(exps[1].next_exp.sensor_inputs[0] == 2), 10) + exps[0].next_exp.sensor_inputs[0] += 3 + self.assertEqual(np.sum(exps[1].sensor_inputs[0] == 1), 10) + exps[1].sensor_inputs[0] += 4 + exps = buf.get_experiences(s) + self.assertEqual(np.sum(exps[0].next_exp.sensor_inputs[0] == 1), 10) + + +if __name__ == '__main__': + unittest.main() -- GitLab