# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from parl.utils import logger

__all__ = ['ReplayMemory']


class ReplayMemory(object):
    def __init__(self, max_size, obs_dim, act_dim):
        self.max_size = int(max_size)
        self.obs_dim = obs_dim
        self.act_dim = act_dim

        self.obs = np.zeros((self.max_size, obs_dim), dtype='float32')
        self.action = np.zeros((self.max_size, act_dim), dtype='float32')
        self.reward = np.zeros((self.max_size, ), dtype='float32')
        self.terminal = np.zeros((self.max_size, ), dtype='bool')
        self.next_obs = np.zeros((self.max_size, obs_dim), dtype='float32')

        self._curr_size = 0
        self._curr_pos = 0

    def sample_batch(self, batch_size):
        # Offset the sampled indices away from the current write position so
        # that the 300 most recently saved transitions are never sampled.
        batch_idx = np.random.randint(
            self._curr_size - 300 - 1, size=batch_size)
        batch_idx = (self._curr_pos + 300 + batch_idx) % self._curr_size

        obs = self.obs[batch_idx]
        reward = self.reward[batch_idx]
        action = self.action[batch_idx]
        next_obs = self.next_obs[batch_idx]
        terminal = self.terminal[batch_idx]
        return obs, action, reward, next_obs, terminal

    def append(self, obs, act, reward, next_obs, terminal):
        # The buffer acts as a ring buffer: once full, the oldest entry at
        # _curr_pos is overwritten.
        if self._curr_size < self.max_size:
            self._curr_size += 1
        self.obs[self._curr_pos] = obs
        self.action[self._curr_pos] = act
        self.reward[self._curr_pos] = reward
        self.next_obs[self._curr_pos] = next_obs
        self.terminal[self._curr_pos] = terminal
        self._curr_pos = (self._curr_pos + 1) % self.max_size

    def size(self):
        return self._curr_size

    def save(self, pathname):
        other = np.array([self._curr_size, self._curr_pos], dtype=np.int32)
        np.savez(
            pathname,
            obs=self.obs,
            action=self.action,
            reward=self.reward,
            terminal=self.terminal,
            next_obs=self.next_obs,
            other=other)

    def load(self, pathname):
        data = np.load(pathname)
        other = data['other']
        if int(other[0]) > self.max_size:
            logger.warn('loading from a bigger size rpm!')
        self._curr_size = min(int(other[0]), self.max_size)
        self._curr_pos = min(int(other[1]), self.max_size - 1)

        self.obs[:self._curr_size] = data['obs'][:self._curr_size]
        self.action[:self._curr_size] = data['action'][:self._curr_size]
        self.reward[:self._curr_size] = data['reward'][:self._curr_size]
        self.terminal[:self._curr_size] = data['terminal'][:self._curr_size]
        self.next_obs[:self._curr_size] = data['next_obs'][:self._curr_size]
        logger.info("[load rpm] memory loaded from {}".format(pathname))
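

# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the original module). The observation
# dimension of 3, action dimension of 1, and buffer size of 10000 are
# hypothetical values chosen purely for illustration. Note that sample_batch()
# skips the 300 most recent transitions, so the buffer must hold more than
# 301 entries before sampling.
if __name__ == '__main__':
    rpm = ReplayMemory(max_size=10000, obs_dim=3, act_dim=1)

    # Fill the buffer with random transitions so that sampling is valid.
    for _ in range(1000):
        obs = np.random.randn(3).astype('float32')
        act = np.random.randn(1).astype('float32')
        reward = float(np.random.rand())
        next_obs = np.random.randn(3).astype('float32')
        terminal = bool(np.random.rand() < 0.05)
        rpm.append(obs, act, reward, next_obs, terminal)

    obs, action, reward, next_obs, terminal = rpm.sample_batch(batch_size=32)
    print('sampled obs batch shape:', obs.shape)        # (32, 3)
    print('sampled action batch shape:', action.shape)  # (32, 1)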