Unverified commit a7ac16df, authored by zenghsh3 and committed by GitHub

fix bug of rpm (#359)

* fix bug of rpm

* fix bug of rpm

* soft link
Parent 04a16723
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from parl.utils import logger

__all__ = ['ReplayMemory']


class ReplayMemory(object):
    def __init__(self, max_size, obs_dim, act_dim):
        self.max_size = int(max_size)
        self.obs_dim = obs_dim
        self.act_dim = act_dim

        # Pre-allocated flat storage for transitions.
        self.obs = np.zeros((max_size, obs_dim), dtype='float32')
        self.action = np.zeros((max_size, act_dim), dtype='float32')
        self.reward = np.zeros((max_size, ), dtype='float32')
        self.terminal = np.zeros((max_size, ), dtype='bool')
        self.next_obs = np.zeros((max_size, obs_dim), dtype='float32')

        self._curr_size = 0
        self._curr_pos = 0

    def sample_batch(self, batch_size):
        # Sample indices uniformly from [0, curr_size - 301); the last
        # 301 slots of the storage arrays are never sampled.
        batch_idx = np.random.randint(
            self._curr_size - 300 - 1, size=batch_size)

        obs = self.obs[batch_idx]
        reward = self.reward[batch_idx]
        action = self.action[batch_idx]
        next_obs = self.next_obs[batch_idx]
        terminal = self.terminal[batch_idx]
        return obs, action, reward, next_obs, terminal

    def make_index(self, batch_size):
        batch_idx = np.random.randint(
            self._curr_size - 300 - 1, size=batch_size)
        return batch_idx

    def sample_batch_by_index(self, batch_idx):
        obs = self.obs[batch_idx]
        reward = self.reward[batch_idx]
        action = self.action[batch_idx]
        next_obs = self.next_obs[batch_idx]
        terminal = self.terminal[batch_idx]
        return obs, action, reward, next_obs, terminal

    def append(self, obs, act, reward, next_obs, terminal):
        if self._curr_size < self.max_size:
            self._curr_size += 1
        self.obs[self._curr_pos] = obs
        self.action[self._curr_pos] = act
        self.reward[self._curr_pos] = reward
        self.next_obs[self._curr_pos] = next_obs
        self.terminal[self._curr_pos] = terminal
        # Ring buffer: once full, overwrite the oldest transition.
        self._curr_pos = (self._curr_pos + 1) % self.max_size

    def size(self):
        return self._curr_size

    def save(self, pathname):
        other = np.array([self._curr_size, self._curr_pos], dtype=np.int32)
        np.savez(
            pathname,
            obs=self.obs,
            action=self.action,
            reward=self.reward,
            terminal=self.terminal,
            next_obs=self.next_obs,
            other=other)

    def load(self, pathname):
        data = np.load(pathname)
        other = data['other']
        if int(other[0]) > self.max_size:
            logger.warn('loading from a bigger size rpm!')
        self._curr_size = min(int(other[0]), self.max_size)
        self._curr_pos = min(int(other[1]), self.max_size - 1)

        self.obs[:self._curr_size] = data['obs'][:self._curr_size]
        self.action[:self._curr_size] = data['action'][:self._curr_size]
        self.reward[:self._curr_size] = data['reward'][:self._curr_size]
        self.terminal[:self._curr_size] = data['terminal'][:self._curr_size]
        self.next_obs[:self._curr_size] = data['next_obs'][:self._curr_size]
        logger.info("[load rpm] memory loaded from {}".format(pathname))
@@ -28,7 +28,8 @@ from concurrent import futures
 from multi_head_ddpg import MultiHeadDDPG
 from opensim_agent import OpenSimAgent
 from opensim_model import OpenSimModel
-from parl.utils import logger, ReplayMemory
+from parl.utils import logger
+from replay_memory import ReplayMemory
 from utils import calc_indicators, ScalarsManager, TransitionExperience

 ACT_DIM = 19
@@ -22,7 +22,7 @@ import numpy as np
 from actor import Actor
 from opensim_model import OpenSimModel
 from opensim_agent import OpenSimAgent
-from parl.utils import logger, ReplayMemory, summary, get_gpu_count
+from parl.utils import logger, summary, get_gpu_count
 from parl.utils.window_stat import WindowStat
 from parl.remote.client import get_global_client
 from parl.utils import machine_info
../NeurIPS2018-AI-for-Prosthetics-Challenge/replay_memory.py
\ No newline at end of file
@@ -22,7 +22,8 @@ import numpy as np
 from actor import Actor
 from opensim_model import OpenSimModel
 from opensim_agent import OpenSimAgent
-from parl.utils import logger, ReplayMemory, summary, get_gpu_count
+from parl.utils import logger, summary, get_gpu_count
+from replay_memory import ReplayMemory
 from parl.utils.window_stat import WindowStat
 from parl.remote.client import get_global_client
 from parl.utils import machine_info
@@ -34,8 +34,7 @@ class ReplayMemory(object):
         self._curr_pos = 0

     def sample_batch(self, batch_size):
-        batch_idx = np.random.randint(
-            self._curr_size - 300 - 1, size=batch_size)
+        batch_idx = np.random.randint(self._curr_size, size=batch_size)
         obs = self.obs[batch_idx]
         reward = self.reward[batch_idx]
         action = self.action[batch_idx]
@@ -45,8 +44,7 @@ class ReplayMemory(object):
         return obs, action, reward, next_obs, terminal

     def make_index(self, batch_size):
-        batch_idx = np.random.randint(
-            self._curr_size - 300 - 1, size=batch_size)
+        batch_idx = np.random.randint(self._curr_size, size=batch_size)
         return batch_idx

     def sample_batch_by_index(self, batch_idx):
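This last hunk is the fix the commit title refers to: the shared ReplayMemory drops the `- 300 - 1` offset, which the challenge-specific copy added earlier in this commit deliberately keeps. With the offset, sampling skips the 301 highest buffer slots, and np.random.randint raises a ValueError whenever fewer than 302 transitions are stored. A quick illustrative check (not part of the commit):

import numpy as np

curr_size = 200  # fewer than 302 stored transitions

# Old behaviour: the high bound is 200 - 301 = -101, so sampling crashes.
try:
    np.random.randint(curr_size - 300 - 1, size=32)
except ValueError as err:
    print('old sampler fails:', err)

# New behaviour: any non-empty buffer can be sampled uniformly.
batch_idx = np.random.randint(curr_size, size=32)
print('new sampler ok:', batch_idx.shape)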