提交 406ffdbf 编写于 作者: X Xiaochen Lian 提交者: Haonan

A simple replay buffer (#5)

* simple replay buffer and its test

* add error handling

* add test for deep copy
上级 7b0407b9
# tilde
*~
# VIM swap file
.*.swp
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class LastExpError(Exception):
    """
    Raised when the last element or an element with non-zero game status is
    sampled.

    Args:
        idx: index of the offending element in the buffer.
        status: game status of that element; any truthy value is reported
            as the status, a falsy value means "last experience of a game".

    Attributes:
        message(string): error message
    """

    def __init__(self, idx, status):
        # Keep (idx, status) as the exception args so that copying/pickling
        # can rebuild the exception with the same constructor arguments.
        super(LastExpError, self).__init__(idx, status)
        if status:
            reason = ' has game status: {}'.format(status)
        else:
            reason = ' is the last experience of a game.'
        self.message = 'The element at {}'.format(idx) + reason

    def __str__(self):
        # Without this, str(err) and tracebacks would only show the raw
        # (idx, status) tuple instead of the descriptive message.
        return self.message
def check_last_exp_error(is_last_exp, idx, game_status):
    """
    Raise LastExpError for the element at `idx` when `is_last_exp` is truthy.

    Args:
        is_last_exp: truthy if the element must not be sampled.
        idx: index of the element, forwarded to LastExpError.
        game_status: game status of the element, forwarded to LastExpError.
    """
    if not is_last_exp:
        return
    raise LastExpError(idx, game_status)
def check_type_error(type1, type2):
    """
    Raise TypeError unless the two types have the same `__name__`.

    Args:
        type1(type): the expected type.
        type2(type): the type that was actually given.
    """
    # NOTE: the comparison is by class name only, not by identity.
    if type1.__name__ == type2.__name__:
        return
    raise TypeError('{} expected, but {} given.'.format(
        type1.__name__, type2.__name__))
def check_eq(v1, v2):
    """Raise ValueError if `v1 != v2` evaluates to true."""
    differs = v1 != v2
    if differs:
        raise ValueError('{} == {} does not hold'.format(v1, v2))
def check_neq(v1, v2):
    """Raise ValueError if `v1 == v2` evaluates to true."""
    same = v1 == v2
    if same:
        raise ValueError('{} != {} does not hold'.format(v1, v2))
def check_gt(v1, v2):
    """Raise ValueError if `v1 <= v2`, i.e. `v1 > v2` does not hold."""
    violated = v1 <= v2
    if violated:
        raise ValueError('{} > {} does not hold'.format(v1, v2))
def check_geq(v1, v2):
    """Raise ValueError if `v1 < v2`, i.e. `v1 >= v2` does not hold."""
    violated = v1 < v2
    if violated:
        raise ValueError('{} >= {} does not hold'.format(v1, v2))
def check_lt(v1, v2):
    """Raise ValueError if `v1 >= v2`, i.e. `v1 < v2` does not hold."""
    violated = v1 >= v2
    if violated:
        raise ValueError('{} < {} does not hold'.format(v1, v2))
def check_leq(v1, v2):
    """Raise ValueError if `v1 > v2`, i.e. `v1 <= v2` does not hold."""
    violated = v1 > v2
    if violated:
        raise ValueError('{} <= {} does not hold'.format(v1, v2))
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import random
from parl.common.error_handling import *
class Experience(object):
    """
    One step of interaction stored in the replay buffer.

    Args:
        sensor_inputs(list): sensor inputs, e.g. (observation, reward);
            a TypeError is raised by `check_type_error` if not a list.
        states: other states.
        actions: actions taken.
        game_status: game status, e.g., max_steps or episode end reached.
    """

    def __init__(self, sensor_inputs, states, actions, game_status):
        # Validate before touching any attribute.
        check_type_error(list, type(sensor_inputs))
        self.sensor_inputs = sensor_inputs
        self.states = states
        self.actions = actions
        self.game_status = game_status
        # Filled in later through set_next_exp().
        self.next_exp = None

    def set_next_exp(self, next_exp):
        """Store an independent deep copy of `next_exp` as the successor."""
        successor = copy.deepcopy(next_exp)
        self.next_exp = successor

    #TODO: write copy function
class Sample(object):
    """
    A Sample represents one or a sequence of Experiences

    Attributes:
        i: starting index of the first experience in the sample
        n: length of the sequence
    """

    def __init__(self, i, n):
        # Assigned left to right, so `i` is stored before `n`.
        self.i, self.n = i, n

    def __repr__(self):
        return '{}: {}'.format(self.__class__, self.__dict__)
class ReplayBuffer(object):
    """
    A circular queue of experiences with a fixed capacity.

    Experiences are deep-copied on the way in (`add`) and on the way out
    (`get_experiences`), so callers never share objects with the buffer.
    """

    def __init__(self, capacity, exp_type=Experience):
        """
        Create Replay buffer.

        Args:
            capacity(int): Max number of experience to store in the buffer.
                When the buffer overflows the old memories are dropped.
                Must be greater than 1 (enforced by `check_gt`).
            exp_type(object): Experience class used in the buffer.
        """
        check_gt(capacity, 1)
        self.buffer = []  # a circular queue to store experiences
        self.capacity = capacity  # capacity of the buffer
        self.last = -1  # the index of the last element in the buffer
        self.exp_type = exp_type  # Experience class used in the buffer

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.buffer)

    def buffer_end(self, i):
        """Return True if `i` is the index of the most recently added item."""
        return i == self.last

    def next_idx(self, i):
        """
        Return the index following `i` in the circular queue, or -1 if `i`
        is the last element and therefore has no successor.
        """
        if self.buffer_end(i):
            return -1
        return (i + 1) % self.capacity

    def add(self, exp):
        """
        Store one experience into the buffer.

        Args:
            exp(self.exp_type): the experience to store in the buffer.
        """
        check_type_error(self.exp_type, type(exp))
        # the next_exp field should be None at this point
        check_eq(exp.next_exp, None)
        # Grow the list until it reaches capacity; afterwards the slot at
        # the advanced `last` index is overwritten.
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.last = (self.last + 1) % self.capacity
        self.buffer[self.last] = copy.deepcopy(exp)

    def sample(self, num_samples):
        """
        Generate a batch of Samples. Each Sample represents a sequence of
        Experiences (length>=1). And a sequence must not cross the boundary
        between two games.

        Args:
            num_samples(int): Number of samples to generate.

        Returns: A generator of Samples. Nothing is generated when the
            buffer holds fewer than two experiences.
        """
        # With zero or one element every index is the last one, so no valid
        # sample exists. Stop here instead of entering the rejection loop,
        # which could never terminate successfully.
        if len(self.buffer) <= 1:
            return
        for _ in range(num_samples):
            # Rejection sampling: redraw until the index is neither the last
            # element nor an element with a truthy game status.
            # NOTE(review): this still spins forever if every non-last
            # element has a truthy game_status; confirm callers cannot
            # produce such a buffer.
            while True:
                idx = random.randint(0, len(self.buffer) - 1)
                if (not self.buffer_end(idx)
                        and not self.buffer[idx].game_status):
                    break
            yield Sample(idx, 1)

    def get_experiences(self, sample):
        """
        Get Experiences from a Sample

        Args:
            sample(Sample): a Sample representing a sequence of Experiences

        Return(list): a list of Experiences, each a deep copy of the stored
            element with `next_exp` set to a copy of its successor.

        Raises:
            LastExpError: if any element of the sequence is the last element
                of the buffer or has a truthy game status.
        """
        exps = []
        p = sample.i
        for _ in range(sample.n):
            check_last_exp_error(
                self.buffer_end(p) or self.buffer[p].game_status, p,
                self.buffer[p].game_status)
            # make a copy of the buffer element as e may be modified somewhere
            e = copy.deepcopy(self.buffer[p])
            p = self.next_idx(p)
            e.set_next_exp(self.buffer[p])
            exps.append(e)
        return exps
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import unittest
from parl.common.error_handling import LastExpError
from parl.common.replay_buffer import Experience, Sample, ReplayBuffer
class ExperienceForTest(Experience):
    """
    Experience subclass used by the tests: packs `obs` and `reward` into the
    sensor inputs, passes an empty list of states, and carries one extra
    attribute, `new_field`.
    """

    def __init__(self, obs, reward, actions, new_field, status):
        sensor_inputs = [obs, reward]
        super(ExperienceForTest, self).__init__(
            sensor_inputs=sensor_inputs,
            states=[],
            actions=actions,
            game_status=status)
        self.new_field = new_field
class TestReplayBuffer(unittest.TestCase):
    """Unit tests for ReplayBuffer: circular storage, sampling, deep copies."""

    def test_single_instance_replay_buffer(self):
        """
        Fill a buffer well past its capacity and check, after every add, that
        the size is right, that the last element cannot be fetched, and that
        sampling never yields the last element or an episode end.
        """
        capacity = 30
        episode_len = 4
        buf = ReplayBuffer(capacity, ExperienceForTest)
        total = 0
        expect_total = 0
        for i in range(10 * capacity):
            e = ExperienceForTest(
                obs=np.zeros(10),
                reward=i * 0.5,
                actions=i,
                new_field=np.ones(20),
                # every episode_len-th experience ends an episode
                status=(i + 1) % episode_len == 0)
            buf.add(e)
            # check the circular queue in the buffer
            self.assertEqual(len(buf), min(i + 1, capacity))
            if len(buf) < 2:  # need at least two elements
                continue
            # should raise error when trying to pick up the last element
            with self.assertRaises(LastExpError):
                t = Sample(i % capacity, 1)
                buf.get_experiences(t)
            expect_total += len(buf)
            # neither last element nor episode end should be picked up
            for s in buf.sample(len(buf)):
                try:
                    buf.get_experiences(s)
                except LastExpError as err:
                    self.fail('test_single_instance_replay_buffer raised '
                              'LastExpError: ' + err.message)
                else:
                    total += 1
        # check the total number of elements added into the buffer
        self.assertEqual(total, expect_total)
        # detect incompatible Experience type
        with self.assertRaises(TypeError):
            e = Experience([np.zeros(10), i * 0.5], [], i, 0)
            buf.add(e)

    def test_deep_copy(self):
        """
        Check that the buffer stores and returns deep copies, so mutating an
        experience before or after retrieval never changes buffer contents.
        """
        capacity = 5
        buf = ReplayBuffer(capacity, Experience)
        e0 = Experience(
            sensor_inputs=[np.zeros(10), 0],
            states=[],
            actions=0,
            game_status=0)
        e1 = Experience([np.ones(10) * 2, 1], [], 0, 1)
        buf.add(e0)
        # mutate e0 after it was added: the stored copy must keep the zeros
        e0.sensor_inputs[0] += 1
        buf.add(e0)
        buf.add(e1)
        s = Sample(0, 2)
        exps = buf.get_experiences(s)
        self.assertEqual(np.sum(exps[0].sensor_inputs[0] == 0), 10)
        self.assertEqual(np.sum(exps[1].sensor_inputs[0] == 1), 10)
        self.assertEqual(np.sum(exps[1].next_exp.sensor_inputs[0] == 2), 10)
        # exps[0].next_exp must be independent of exps[1]
        exps[0].next_exp.sensor_inputs[0] += 3
        self.assertEqual(np.sum(exps[1].sensor_inputs[0] == 1), 10)
        # mutating a returned experience must not leak back into the buffer
        exps[1].sensor_inputs[0] += 4
        exps = buf.get_experiences(s)
        self.assertEqual(np.sum(exps[0].next_exp.sensor_inputs[0] == 1), 10)
# Run the unit tests only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册