Unverified commit de21118e, authored by LI Yunxiang, committed by GitHub

state to obs (#256)

* state to obs

* yapf & update softlink in offline-q-learning
Parent 6fa2d081
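The diff below is essentially a mechanical rename: the replay buffers, training scripts, and the SAC algorithm switch from `state` naming to `obs` naming, with no behavioural change. For orientation, a minimal sketch of the renamed Atari replay-buffer interface follows; the buffer size, the 84x84 frame shape, and the context length are placeholder values for illustration, not read from this commit.

import numpy as np
from replay_memory import ReplayMemory, Experience  # the module edited below

# Placeholder sizes, for illustration only.
rpm = ReplayMemory(max_size=1000, obs_shape=(84, 84), context_len=4)

obs = np.zeros((84, 84), dtype='uint8')
# Each transition now carries an `obs` field instead of `state`.
rpm.append(Experience(obs=obs, action=0, reward=0.0, isOver=False))

# recent_state() becomes recent_obs(); the training loops stack it with the
# current frame before calling agent.sample(...), exactly as in the diff.
context = rpm.recent_obs()
context.append(obs)
context = np.stack(context, axis=0)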
@@ -16,16 +16,16 @@ import numpy as np
 import copy
 from collections import deque, namedtuple
-Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])
+Experience = namedtuple('Experience', ['obs', 'action', 'reward', 'isOver'])
 class ReplayMemory(object):
-    def __init__(self, max_size, state_shape, context_len):
+    def __init__(self, max_size, obs_shape, context_len):
         self.max_size = int(max_size)
-        self.state_shape = state_shape
+        self.obs_shape = obs_shape
         self.context_len = int(context_len)
-        self.state = np.zeros((self.max_size, ) + state_shape, dtype='uint8')
+        self.obs = np.zeros((self.max_size, ) + obs_shape, dtype='uint8')
         self.action = np.zeros((self.max_size, ), dtype='int32')
         self.reward = np.zeros((self.max_size, ), dtype='float32')
         self.isOver = np.zeros((self.max_size, ), dtype='bool')
@@ -48,42 +48,41 @@ class ReplayMemory(object):
         else:
             self._context.append(exp)
-    def recent_state(self):
-        """ maintain recent state for training"""
+    def recent_obs(self):
+        """ maintain recent obs for training"""
         lst = list(self._context)
-        states = [np.zeros(self.state_shape, dtype='uint8')] * \
+        obs = [np.zeros(self.obs_shape, dtype='uint8')] * \
             (self._context.maxlen - len(lst))
-        states.extend([k.state for k in lst])
-        return states
+        obs.extend([k.obs for k in lst])
+        return obs
     def sample(self, idx):
-        """ return state, action, reward, isOver,
-            note that some frames in state may be generated from last episode,
-            they should be removed from state
+        """ return obs, action, reward, isOver,
+            note that some frames in obs may be generated from last episode,
+            they should be removed from obs
         """
-        state = np.zeros(
-            (self.context_len + 1, ) + self.state_shape, dtype=np.uint8)
-        state_idx = np.arange(idx,
-                              idx + self.context_len + 1) % self._curr_size
+        obs = np.zeros(
+            (self.context_len + 1, ) + self.obs_shape, dtype=np.uint8)
+        obs_idx = np.arange(idx, idx + self.context_len + 1) % self._curr_size
         # confirm that no frame was generated from last episode
         has_last_episode = False
         for k in range(self.context_len - 2, -1, -1):
-            to_check_idx = state_idx[k]
+            to_check_idx = obs_idx[k]
             if self.isOver[to_check_idx]:
                 has_last_episode = True
-                state_idx = state_idx[k + 1:]
-                state[k + 1:] = self.state[state_idx]
+                obs_idx = obs_idx[k + 1:]
+                obs[k + 1:] = self.obs[obs_idx]
                 break
         if not has_last_episode:
-            state = self.state[state_idx]
+            obs = self.obs[obs_idx]
         real_idx = (idx + self.context_len - 1) % self._curr_size
         action = self.action[real_idx]
         reward = self.reward[real_idx]
         isOver = self.isOver[real_idx]
-        return state, reward, action, isOver
+        return obs, reward, action, isOver
     def __len__(self):
         return self._curr_size
@@ -92,7 +91,7 @@ class ReplayMemory(object):
         return self._curr_size
     def _assign(self, pos, exp):
-        self.state[pos] = exp.state
+        self.obs[pos] = exp.obs
         self.reward[pos] = exp.reward
         self.action[pos] = exp.action
         self.isOver[pos] = exp.isOver
@@ -107,8 +106,8 @@ class ReplayMemory(object):
         return self._process_batch(batch_exp)
     def _process_batch(self, batch_exp):
-        state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
+        obs = np.asarray([e[0] for e in batch_exp], dtype='uint8')
         reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
         action = np.asarray([e[2] for e in batch_exp], dtype='int8')
         isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
-        return [state, action, reward, isOver]
+        return [obs, action, reward, isOver]
@@ -26,7 +26,7 @@ from parl.utils import tensorboard, logger
 from parl.algorithms import DQN, DDQN
 from agent import AtariAgent
-from atari_wrapper import FireResetEnv, FrameStack, LimitLength, MapState
+from atari_wrapper import FireResetEnv, FrameStack, LimitLength
 from model import AtariModel
 from replay_memory import ReplayMemory, Experience
 from utils import get_player
@@ -43,57 +43,57 @@ GAMMA = 0.99
 def run_train_episode(env, agent, rpm):
     total_reward = 0
     all_cost = []
-    state = env.reset()
+    obs = env.reset()
     steps = 0
     while True:
         steps += 1
-        context = rpm.recent_state()
-        context.append(state)
+        context = rpm.recent_obs()
+        context.append(obs)
         context = np.stack(context, axis=0)
         action = agent.sample(context)
-        next_state, reward, isOver, _ = env.step(action)
-        rpm.append(Experience(state, action, reward, isOver))
+        next_obs, reward, isOver, _ = env.step(action)
+        rpm.append(Experience(obs, action, reward, isOver))
         if rpm.size() > MEMORY_WARMUP_SIZE:
             if steps % UPDATE_FREQ == 0:
-                batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_batch(
+                batch_all_obs, batch_action, batch_reward, batch_isOver = rpm.sample_batch(
                     args.batch_size)
-                batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
-                batch_next_state = batch_all_state[:, 1:, :, :]
-                cost = agent.learn(batch_state, batch_action, batch_reward,
-                                   batch_next_state, batch_isOver)
+                batch_obs = batch_all_obs[:, :CONTEXT_LEN, :, :]
+                batch_next_obs = batch_all_obs[:, 1:, :, :]
+                cost = agent.learn(batch_obs, batch_action, batch_reward,
                                   batch_next_obs, batch_isOver)
                 all_cost.append(cost)
         total_reward += reward
-        state = next_state
+        obs = next_obs
         if isOver:
             mean_loss = np.mean(all_cost) if all_cost else None
             return total_reward, steps, mean_loss
 def run_evaluate_episode(env, agent):
-    state = env.reset()
+    obs = env.reset()
     total_reward = 0
     while True:
-        pred_Q = agent.predict(state)
+        pred_Q = agent.predict(obs)
         action = pred_Q.max(1)[1].item()
-        state, reward, isOver, _ = env.step(action)
+        obs, reward, isOver, _ = env.step(action)
         total_reward += reward
         if isOver:
             return total_reward
-def get_fixed_states(rpm, batch_size):
-    states = []
+def get_fixed_obs(rpm, batch_size):
+    obs = []
     for _ in range(3):
-        batch_all_state = rpm.sample_batch(batch_size)[0]
-        batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
-        states.append(batch_state)
-    fixed_states = np.concatenate(states, axis=0)
-    return fixed_states
+        batch_all_obs = rpm.sample_batch(batch_size)[0]
+        batch_obs = batch_all_obs[:, :CONTEXT_LEN, :, :]
+        obs.append(batch_obs)
+    fixed_obs = np.concatenate(obs, axis=0)
+    return fixed_obs
-def evaluate_fixed_Q(agent, states):
+def evaluate_fixed_Q(agent, obs):
     with torch.no_grad():
-        max_pred_Q = agent.alg.model(states).max(1)[0].mean()
+        max_pred_Q = agent.alg.model(obs).max(1)[0].mean()
         return max_pred_Q.item()
@@ -131,9 +131,9 @@ def main():
             total_reward, steps, _ = run_train_episode(env, agent, rpm)
             pbar.update(steps)
-    # Get fixed states to check value function.
-    fixed_states = get_fixed_states(rpm, args.batch_size)
-    fixed_states = torch.tensor(fixed_states, dtype=torch.float, device=device)
+    # Get fixed obs to check value function.
+    fixed_obs = get_fixed_obs(rpm, args.batch_size)
+    fixed_obs = torch.tensor(fixed_obs, dtype=torch.float, device=device)
     # train
     test_flag = 0
@@ -159,7 +159,7 @@ def main():
             tensorboard.add_scalar('dqn/exploration', agent.exploration,
                                    total_steps)
             tensorboard.add_scalar('dqn/Q value',
-                                   evaluate_fixed_Q(agent, fixed_states),
+                                   evaluate_fixed_Q(agent, fixed_obs),
                                    total_steps)
             tensorboard.add_scalar('dqn/grad_norm',
                                    get_grad_norm(agent.alg.model),
...
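A note on the slicing in run_train_episode above: sample_batch returns CONTEXT_LEN + 1 consecutive frames per transition, so the current and next observations are just two overlapping windows of the same array. A small numpy sketch of that trick; the batch size of 32, CONTEXT_LEN = 4, and the 84x84 frame size are assumptions for illustration, not values read from this diff.

import numpy as np

CONTEXT_LEN = 4  # assumed frame-stack length

# Pretend sample_batch returned 32 transitions, each with CONTEXT_LEN + 1 frames.
batch_all_obs = np.zeros((32, CONTEXT_LEN + 1, 84, 84), dtype='uint8')

batch_obs = batch_all_obs[:, :CONTEXT_LEN, :, :]  # frames up to time t
batch_next_obs = batch_all_obs[:, 1:, :, :]       # the same window shifted by one step

assert batch_obs.shape == batch_next_obs.shape == (32, CONTEXT_LEN, 84, 84)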
@@ -21,13 +21,13 @@ from parl import layers
 class CartpoleAgent(parl.Agent):
     def __init__(self,
                  algorithm,
-                 state_dim,
+                 obs_dim,
                  act_dim,
                  e_greed=0.1,
                  e_greed_decrement=0):
-        assert isinstance(state_dim, int)
+        assert isinstance(obs_dim, int)
         assert isinstance(act_dim, int)
-        self.state_dim = state_dim
+        self.obs_dim = obs_dim
         self.act_dim = act_dim
         super(CartpoleAgent, self).__init__(algorithm)
@@ -43,16 +43,16 @@ class CartpoleAgent(parl.Agent):
         with fluid.program_guard(self.pred_program):
             obs = layers.data(
-                name='obs', shape=[self.state_dim], dtype='float32')
+                name='obs', shape=[self.obs_dim], dtype='float32')
             self.value = self.alg.predict(obs)
         with fluid.program_guard(self.learn_program):
             obs = layers.data(
-                name='obs', shape=[self.state_dim], dtype='float32')
+                name='obs', shape=[self.obs_dim], dtype='float32')
             action = layers.data(name='act', shape=[1], dtype='int32')
             reward = layers.data(name='reward', shape=[], dtype='float32')
             next_obs = layers.data(
-                name='next_obs', shape=[self.state_dim], dtype='float32')
+                name='next_obs', shape=[self.obs_dim], dtype='float32')
             terminal = layers.data(name='terminal', shape=[], dtype='bool')
             lr = layers.data(
                 name='lr', shape=[1], dtype='float32', append_batch_size=False)
...
@@ -28,19 +28,19 @@ class ReplayMemory(object):
     def sample(self, batch_size):
         mini_batch = random.sample(self.buffer, batch_size)
-        state_batch, action_batch, reward_batch, next_state_batch, done_batch = [], [], [], [], []
+        obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], []
         for experience in mini_batch:
             s, a, r, s_p, done = experience
-            state_batch.append(s)
+            obs_batch.append(s)
             action_batch.append(a)
             reward_batch.append(r)
-            next_state_batch.append(s_p)
+            next_obs_batch.append(s_p)
             done_batch.append(done)
-        return np.array(state_batch).astype('float32'), \
+        return np.array(obs_batch).astype('float32'), \
             np.array(action_batch).astype('float32'), np.array(reward_batch).astype('float32'),\
-            np.array(next_state_batch).astype('float32'), np.array(done_batch).astype('float32')
+            np.array(next_obs_batch).astype('float32'), np.array(done_batch).astype('float32')
     def __len__(self):
         return len(self.buffer)
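The CartPole buffer above is the simple tuple-based variant; only the names inside sample change. A minimal sketch of how the renamed fields are used by the training script that follows, assuming the module is importable as replay_memory; the capacity and batch size are placeholders.

import numpy as np
from replay_memory import ReplayMemory  # the CartPole version edited above

rpm = ReplayMemory(200)             # constructed as in main(); capacity is a placeholder
obs = np.zeros(4, dtype='float32')  # CartPole observation: 4 floats
for _ in range(32):
    # transitions are stored as (obs, action, reward, next_obs, isOver) tuples
    rpm.append((obs, 0, 1.0, obs, False))

(batch_obs, batch_action, batch_reward,
 batch_next_obs, batch_isOver) = rpm.sample(16)  # five float32 arrays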
@@ -32,24 +32,24 @@ GAMMA = 0.99 # discount factor of reward
 def run_episode(agent, env, rpm):
     total_reward = 0
-    state = env.reset()
+    obs = env.reset()
     step = 0
     while True:
         step += 1
-        action = agent.sample(state)
-        next_state, reward, isOver, _ = env.step(action)
-        rpm.append((state, action, reward, next_state, isOver))
+        action = agent.sample(obs)
+        next_obs, reward, isOver, _ = env.step(action)
+        rpm.append((obs, action, reward, next_obs, isOver))
         # train model
         if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):
-            (batch_state, batch_action, batch_reward, batch_next_state,
+            (batch_obs, batch_action, batch_reward, batch_next_obs,
              batch_isOver) = rpm.sample(BATCH_SIZE)
-            train_loss = agent.learn(batch_state, batch_action, batch_reward,
-                                     batch_next_state, batch_isOver,
+            train_loss = agent.learn(batch_obs, batch_action, batch_reward,
+                                     batch_next_obs, batch_isOver,
                                      LEARNING_RATE)
         total_reward += reward
-        state = next_state
+        obs = next_obs
         if isOver:
             break
     return total_reward
@@ -59,14 +59,14 @@ def evaluate(agent, env, render=False):
     # test part, run 5 episodes and average
     eval_reward = []
     for i in range(5):
-        state = env.reset()
+        obs = env.reset()
         episode_reward = 0
         isOver = False
         while not isOver:
-            action = agent.predict(state)
+            action = agent.predict(obs)
             if render:
                 env.render()
-            state, reward, isOver, _ = env.step(action)
+            obs, reward, isOver, _ = env.step(action)
             episode_reward += reward
         eval_reward.append(episode_reward)
     return np.mean(eval_reward)
@@ -75,7 +75,7 @@ def evaluate(agent, env, render=False):
 def main():
     env = gym.make('CartPole-v1')
     action_dim = env.action_space.n
-    state_shape = env.observation_space.shape
+    obs_shape = env.observation_space.shape
     rpm = ReplayMemory(MEMORY_SIZE)
@@ -83,7 +83,7 @@ def main():
     algorithm = parl.algorithms.DQN(model, act_dim=action_dim, gamma=GAMMA)
     agent = CartpoleAgent(
         algorithm,
-        state_dim=state_shape[0],
+        obs_dim=obs_shape[0],
         act_dim=action_dim,
         e_greed=0.1,  # explore
         e_greed_decrement=1e-6
...
@@ -16,16 +16,16 @@ import numpy as np
 import copy
 from collections import deque, namedtuple
-Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])
+Experience = namedtuple('Experience', ['obs', 'action', 'reward', 'isOver'])
 class ReplayMemory(object):
-    def __init__(self, max_size, state_shape, context_len):
+    def __init__(self, max_size, obs_shape, context_len):
         self.max_size = int(max_size)
-        self.state_shape = state_shape
+        self.obs_shape = obs_shape
         self.context_len = int(context_len)
-        self.state = np.zeros((self.max_size, ) + state_shape, dtype='uint8')
+        self.obs = np.zeros((self.max_size, ) + obs_shape, dtype='uint8')
         self.action = np.zeros((self.max_size, ), dtype='int32')
         self.reward = np.zeros((self.max_size, ), dtype='float32')
         self.isOver = np.zeros((self.max_size, ), dtype='bool')
@@ -48,42 +48,41 @@ class ReplayMemory(object):
         else:
             self._context.append(exp)
-    def recent_state(self):
-        """ maintain recent state for training"""
+    def recent_obs(self):
+        """ maintain recent obs for training"""
         lst = list(self._context)
-        states = [np.zeros(self.state_shape, dtype='uint8')] * \
+        obs = [np.zeros(self.obs_shape, dtype='uint8')] * \
            (self._context.maxlen - len(lst))
-        states.extend([k.state for k in lst])
-        return states
+        obs.extend([k.obs for k in lst])
+        return obs
     def sample(self, idx):
-        """ return state, action, reward, isOver,
-            note that some frames in state may be generated from last episode,
-            they should be removed from state
+        """ return obs, action, reward, isOver,
+            note that some frames in obs may be generated from last episode,
+            they should be removed from obs
        """
-        state = np.zeros(
-            (self.context_len + 1, ) + self.state_shape, dtype=np.uint8)
-        state_idx = np.arange(idx,
-                              idx + self.context_len + 1) % self._curr_size
+        obs = np.zeros(
+            (self.context_len + 1, ) + self.obs_shape, dtype=np.uint8)
+        obs_idx = np.arange(idx, idx + self.context_len + 1) % self._curr_size
         # confirm that no frame was generated from last episode
         has_last_episode = False
         for k in range(self.context_len - 2, -1, -1):
-            to_check_idx = state_idx[k]
+            to_check_idx = obs_idx[k]
            if self.isOver[to_check_idx]:
                 has_last_episode = True
-                state_idx = state_idx[k + 1:]
-                state[k + 1:] = self.state[state_idx]
+                obs_idx = obs_idx[k + 1:]
+                obs[k + 1:] = self.obs[obs_idx]
                 break
         if not has_last_episode:
-            state = self.state[state_idx]
+            obs = self.obs[obs_idx]
         real_idx = (idx + self.context_len - 1) % self._curr_size
         action = self.action[real_idx]
         reward = self.reward[real_idx]
         isOver = self.isOver[real_idx]
-        return state, reward, action, isOver
+        return obs, reward, action, isOver
     def __len__(self):
         return self._curr_size
@@ -92,7 +91,7 @@ class ReplayMemory(object):
         return self._curr_size
     def _assign(self, pos, exp):
-        self.state[pos] = exp.state
+        self.obs[pos] = exp.obs
         self.reward[pos] = exp.reward
         self.action[pos] = exp.action
         self.isOver[pos] = exp.isOver
@@ -107,8 +106,8 @@ class ReplayMemory(object):
         return self._process_batch(batch_exp)
     def _process_batch(self, batch_exp):
-        state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
+        obs = np.asarray([e[0] for e in batch_exp], dtype='uint8')
         reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
         action = np.asarray([e[2] for e in batch_exp], dtype='int8')
         isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
-        return [state, action, reward, isOver]
+        return [obs, action, reward, isOver]
@@ -39,28 +39,28 @@ LEARNING_RATE = 3e-4
 def run_train_episode(env, agent, rpm):
     total_reward = 0
     all_cost = []
-    state = env.reset()
+    obs = env.reset()
     steps = 0
     while True:
         steps += 1
-        context = rpm.recent_state()
-        context.append(state)
+        context = rpm.recent_obs()
+        context.append(obs)
         context = np.stack(context, axis=0)
         action = agent.sample(context)
-        next_state, reward, isOver, _ = env.step(action)
-        rpm.append(Experience(state, action, reward, isOver))
+        next_obs, reward, isOver, _ = env.step(action)
+        rpm.append(Experience(obs, action, reward, isOver))
         # start training
         if rpm.size() > MEMORY_WARMUP_SIZE:
             if steps % UPDATE_FREQ == 0:
-                batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_batch(
+                batch_all_obs, batch_action, batch_reward, batch_isOver = rpm.sample_batch(
                     args.batch_size)
-                batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
-                batch_next_state = batch_all_state[:, 1:, :, :]
-                cost = agent.learn(batch_state, batch_action, batch_reward,
-                                   batch_next_state, batch_isOver)
+                batch_obs = batch_all_obs[:, :CONTEXT_LEN, :, :]
+                batch_next_obs = batch_all_obs[:, 1:, :, :]
+                cost = agent.learn(batch_obs, batch_action, batch_reward,
+                                   batch_next_obs, batch_isOver)
                 all_cost.append(float(cost))
         total_reward += reward
-        state = next_state
+        obs = next_obs
         if isOver:
             break
     if all_cost:
@@ -70,11 +70,11 @@ def run_train_episode(env, agent, rpm):
 def run_evaluate_episode(env, agent):
-    state = env.reset()
+    obs = env.reset()
     total_reward = 0
     while True:
-        action = agent.predict(state)
-        state, reward, isOver, info = env.step(action)
+        action = agent.predict(obs)
+        obs, reward, isOver, info = env.step(action)
         total_reward += reward
         if isOver:
             break
...
-../DQN/atari.py
+../DQN_variant/atari.py
\ No newline at end of file
-../DQN/atari_wrapper.py
+../DQN_variant/atari_wrapper.py
\ No newline at end of file
@@ -45,21 +45,21 @@ gpu_num = get_gpu_count()
 def run_train_step(agent, rpm):
     for step in range(args.train_total_steps):
         # use the first 80% data to train
-        batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_batch(
+        batch_all_obs, batch_action, batch_reward, batch_isOver = rpm.sample_batch(
             args.batch_size * gpu_num)
-        batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
-        batch_next_state = batch_all_state[:, 1:, :, :]
-        cost = agent.learn(batch_state, batch_action, batch_reward,
-                           batch_next_state, batch_isOver)
+        batch_obs = batch_all_obs[:, :CONTEXT_LEN, :, :]
+        batch_next_obs = batch_all_obs[:, 1:, :, :]
+        cost = agent.learn(batch_obs, batch_action, batch_reward,
+                           batch_next_obs, batch_isOver)
         if step % 100 == 0:
             # use the last 20% data to evaluate
-            batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_test_batch(
+            batch_all_obs, batch_action, batch_reward, batch_isOver = rpm.sample_test_batch(
                 args.batch_size)
-            batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
-            batch_next_state = batch_all_state[:, 1:, :, :]
-            eval_cost = agent.supervised_eval(batch_state, batch_action,
-                                              batch_reward, batch_next_state,
+            batch_obs = batch_all_obs[:, :CONTEXT_LEN, :, :]
+            batch_next_obs = batch_all_obs[:, 1:, :, :]
+            eval_cost = agent.supervised_eval(batch_obs, batch_action,
+                                              batch_reward, batch_next_obs,
                                               batch_isOver)
             logger.info(
                 "train step {}, train costs are {}, eval cost is {}.".format(
@@ -67,17 +67,17 @@ def run_train_step(agent, rpm):
 def collect_exp(env, rpm, agent):
-    state = env.reset()
+    obs = env.reset()
     # collect data to fulfill replay memory
     for i in tqdm(range(MEMORY_SIZE)):
-        context = rpm.recent_state()
-        context.append(state)
+        context = rpm.recent_obs()
+        context.append(obs)
         context = np.stack(context, axis=0)
         action = agent.sample(context)
-        next_state, reward, isOver, _ = env.step(action)
-        rpm.append(Experience(state, action, reward, isOver))
-        state = next_state
+        next_obs, reward, isOver, _ = env.step(action)
+        rpm.append(Experience(obs, action, reward, isOver))
+        obs = next_obs
 def main():
...
@@ -18,18 +18,18 @@ import os
 from collections import deque, namedtuple
 from parl.utils import logger
-Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])
+Experience = namedtuple('Experience', ['obs', 'action', 'reward', 'isOver'])
 class ReplayMemory(object):
     def __init__(self,
                  max_size,
-                 state_shape,
+                 obs_shape,
                  context_len,
                  load_file=False,
                  file_path=None):
         self.max_size = int(max_size)
-        self.state_shape = state_shape
+        self.obs_shape = obs_shape
         self.context_len = int(context_len)
         self.file_path = file_path
@@ -38,8 +38,7 @@ class ReplayMemory(object):
             self.load_memory()
             logger.info("memory size is {}".format(self._curr_size))
         else:
-            self.state = np.zeros(
-                (self.max_size, ) + state_shape, dtype='uint8')
+            self.obs = np.zeros((self.max_size, ) + obs_shape, dtype='uint8')
             self.action = np.zeros((self.max_size, ), dtype='int32')
             self.reward = np.zeros((self.max_size, ), dtype='float32')
             self.isOver = np.zeros((self.max_size, ), dtype='bool')
@@ -62,42 +61,41 @@ class ReplayMemory(object):
         else:
             self._context.append(exp)
-    def recent_state(self):
-        """ maintain recent state for training"""
+    def recent_obs(self):
+        """ maintain recent obs for training"""
         lst = list(self._context)
-        states = [np.zeros(self.state_shape, dtype='uint8')] * \
+        obs = [np.zeros(self.obs_shape, dtype='uint8')] * \
            (self._context.maxlen - len(lst))
-        states.extend([k.state for k in lst])
-        return states
+        obs.extend([k.obs for k in lst])
+        return obs
     def sample(self, idx):
-        """ return state, action, reward, isOver,
-            note that some frames in state may be generated from last episode,
-            they should be removed from state
+        """ return obs, action, reward, isOver,
+            note that some frames in obs may be generated from last episode,
+            they should be removed from obs
        """
-        state = np.zeros(
-            (self.context_len + 1, ) + self.state_shape, dtype=np.uint8)
-        state_idx = np.arange(idx,
-                              idx + self.context_len + 1) % self._curr_size
+        obs = np.zeros(
+            (self.context_len + 1, ) + self.obs_shape, dtype=np.uint8)
+        obs_idx = np.arange(idx, idx + self.context_len + 1) % self._curr_size
         # confirm that no frame was generated from last episode
         has_last_episode = False
         for k in range(self.context_len - 2, -1, -1):
-            to_check_idx = state_idx[k]
+            to_check_idx = obs_idx[k]
            if self.isOver[to_check_idx]:
                 has_last_episode = True
-                state_idx = state_idx[k + 1:]
-                state[k + 1:] = self.state[state_idx]
+                obs_idx = obs_idx[k + 1:]
+                obs[k + 1:] = self.obs[obs_idx]
                 break
         if not has_last_episode:
-            state = self.state[state_idx]
+            obs = self.obs[obs_idx]
         real_idx = (idx + self.context_len - 1) % self._curr_size
         action = self.action[real_idx]
         reward = self.reward[real_idx]
         isOver = self.isOver[real_idx]
-        return state, reward, action, isOver
+        return obs, reward, action, isOver
     def __len__(self):
         return self._curr_size
@@ -106,7 +104,7 @@ class ReplayMemory(object):
         return self._curr_size
     def _assign(self, pos, exp):
-        self.state[pos] = exp.state
+        self.obs[pos] = exp.obs
         self.reward[pos] = exp.reward
         self.action[pos] = exp.action
         self.isOver[pos] = exp.isOver
@@ -129,15 +127,15 @@ class ReplayMemory(object):
         return self._process_batch(batch_exp)
     def _process_batch(self, batch_exp):
-        state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
+        obs = np.asarray([e[0] for e in batch_exp], dtype='uint8')
         reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
         action = np.asarray([e[2] for e in batch_exp], dtype='int8')
         isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
-        return [state, action, reward, isOver]
+        return [obs, action, reward, isOver]
     def save_memory(self):
         save_data = [
-            self.state, self.reward, self.action, self.isOver, self._curr_size,
+            self.obs, self.reward, self.action, self.isOver, self._curr_size,
             self._curr_pos, self._context
         ]
         np.savez(self.file_path, *save_data)
@@ -145,7 +143,7 @@ class ReplayMemory(object):
     def load_memory(self):
         container = np.load(self.file_path, allow_pickle=True)
         [
-            self.state, self.reward, self.action, self.isOver, self._curr_size,
+            self.obs, self.reward, self.action, self.isOver, self._curr_size,
             self._curr_pos, self._context
         ] = [container[key] for key in container]
         self._curr_size = self._curr_size.astype(int)
...
-../DQN/rom_files/
+../DQN_variant/rom_files
\ No newline at end of file
-../DQN/utils.py
+../DQN_variant/utils.py
\ No newline at end of file
@@ -102,11 +102,11 @@ class SAC(Algorithm):
         return cost
     def critic_learn(self, obs, action, reward, next_obs, terminal):
-        next_state_action, next_state_log_pi = self.sample(next_obs)
+        next_obs_action, next_obs_log_pi = self.sample(next_obs)
         qf1_next_target, qf2_next_target = self.target_critic.value(
-            next_obs, next_state_action)
+            next_obs, next_obs_action)
         min_qf_next_target = layers.elementwise_min(
-            qf1_next_target, qf2_next_target) - next_state_log_pi * self.alpha
+            qf1_next_target, qf2_next_target) - next_obs_log_pi * self.alpha
         terminal = layers.cast(terminal, dtype='float32')
         target_Q = reward + (1.0 - terminal) * self.gamma * min_qf_next_target
...
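The renamed variables in critic_learn still compute the usual SAC soft target: the minimum of the two target critics at (next_obs, next_obs_action), minus the entropy term, discounted and masked by the terminal flag. A scalar sketch with made-up numbers; alpha = 0.2 and gamma = 0.99 here are illustrative values, not necessarily the repo defaults.

gamma, alpha = 0.99, 0.2                       # illustrative values
reward, terminal = 1.0, 0.0                    # one hypothetical transition
qf1_next_target, qf2_next_target = 10.3, 9.8   # target critic values
next_obs_log_pi = -1.5                         # log-prob of the sampled next action

min_qf_next_target = min(qf1_next_target, qf2_next_target) - next_obs_log_pi * alpha
target_Q = reward + (1.0 - terminal) * gamma * min_qf_next_target
print(target_Q)  # 1.0 + 0.99 * (9.8 + 0.3) ≈ 10.999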