Unverified · Commit a9159021 authored by rical730, committed by GitHub

upgrade DQN's lr interface compatibility (#291)

* upgrade DQN's lr interface compatibility

* yapf

* update example DQN
Parent 533d4b2c
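In short, this commit moves the learning rate from every learn() call into the algorithm constructor, while learn() keeps an optional learning_rate argument for backward compatibility. A minimal sketch of the two calling styles (model, obs, act_dim and the constants GAMMA / LEARNING_RATE stand in for the objects defined in the CartPole example below):

# Old style (before this commit): the learning rate is an argument of learn().
algorithm = parl.algorithms.DQN(model, act_dim=act_dim, gamma=GAMMA)
cost = algorithm.learn(obs, action, reward, next_obs, terminal, LEARNING_RATE)

# New style: set lr once at construction; learn() no longer needs it.
algorithm = parl.algorithms.DQN(
    model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE)
cost = algorithm.learn(obs, action, reward, next_obs, terminal)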
@@ -54,10 +54,7 @@ class CartpoleAgent(parl.Agent):
             next_obs = layers.data(
                 name='next_obs', shape=[self.obs_dim], dtype='float32')
             terminal = layers.data(name='terminal', shape=[], dtype='bool')
-            lr = layers.data(
-                name='lr', shape=[1], dtype='float32', append_batch_size=False)
-            self.cost = self.alg.learn(obs, action, reward, next_obs, terminal,
-                                       lr)
+            self.cost = self.alg.learn(obs, action, reward, next_obs, terminal)

     def sample(self, obs):
         sample = np.random.rand()
@@ -78,7 +75,7 @@ class CartpoleAgent(parl.Agent):
         act = np.argmax(pred_Q)
         return act

-    def learn(self, obs, act, reward, next_obs, terminal, lr):
+    def learn(self, obs, act, reward, next_obs, terminal):
         if self.global_step % self.update_target_steps == 0:
             self.alg.sync_target()
         self.global_step += 1
@@ -90,7 +87,6 @@ class CartpoleAgent(parl.Agent):
             'reward': reward,
             'next_obs': next_obs.astype('float32'),
             'terminal': terminal,
-            'lr': np.float32([lr]),
         }
         cost = self.fluid_executor.run(
             self.learn_program, feed=feed, fetch_list=[self.cost])[0]
......
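An agent written against the old interface does not have to change: the algorithm's learn() still accepts an explicit learning rate (see the DQN/DDQN diffs below), so the removed lr feed tensor above keeps working if passed as learning_rate, and the rate can then still be fed at run time as before. A hedged sketch of that compatibility path, assuming the same build_program context as the example agent (the tensor names 'obs' and 'act' are assumptions):

# Inside the agent's build_program(), within fluid.program_guard(self.learn_program):
obs = layers.data(name='obs', shape=[self.obs_dim], dtype='float32')
action = layers.data(name='act', shape=[1], dtype='int32')
reward = layers.data(name='reward', shape=[], dtype='float32')
next_obs = layers.data(
    name='next_obs', shape=[self.obs_dim], dtype='float32')
terminal = layers.data(name='terminal', shape=[], dtype='bool')
lr = layers.data(
    name='lr', shape=[1], dtype='float32', append_batch_size=False)
# Still valid after this commit: an explicit learning_rate overrides self.lr,
# so 'lr' can keep being supplied through the feed dict on every update.
self.cost = self.alg.learn(
    obs, action, reward, next_obs, terminal, learning_rate=lr)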
@@ -45,8 +45,7 @@ def run_episode(agent, env, rpm):
             (batch_obs, batch_action, batch_reward, batch_next_obs,
              batch_isOver) = rpm.sample(BATCH_SIZE)
-            train_loss = agent.learn(batch_obs, batch_action, batch_reward,
-                                     batch_next_obs, batch_isOver,
-                                     LEARNING_RATE)
+            train_loss = agent.learn(batch_obs, batch_action, batch_reward,
+                                     batch_next_obs, batch_isOver)

         total_reward += reward
         obs = next_obs
@@ -80,7 +79,8 @@ def main():
     rpm = ReplayMemory(MEMORY_SIZE)

     model = CartpoleModel(act_dim=action_dim)
-    algorithm = parl.algorithms.DQN(model, act_dim=action_dim, gamma=GAMMA)
+    algorithm = parl.algorithms.DQN(
+        model, act_dim=action_dim, gamma=GAMMA, lr=LEARNING_RATE)
     agent = CartpoleAgent(
         algorithm,
         obs_dim=obs_shape[0],
......
@@ -21,19 +21,17 @@ import paddle.fluid as fluid
 from parl.core.fluid.algorithm import Algorithm
 from parl.core.fluid import layers

+__all__ = ['DDQN']


 class DDQN(Algorithm):
-    def __init__(
-            self,
-            model,
-            act_dim=None,
-            gamma=None,
-    ):
+    def __init__(self, model, act_dim=None, gamma=None, lr=None):
         """ Double DQN algorithm

         Args:
-            model (parl.Model): model defining forward network of Q function.
+            model (parl.Model): model defining forward network of Q function
+            act_dim (int): dimension of the action space
             gamma (float): discounted factor for reward computation.
+            lr (float): learning rate.
         """
         self.model = model
         self.target_model = copy.deepcopy(model)
@@ -43,11 +41,29 @@ class DDQN(Algorithm):
         self.act_dim = act_dim
         self.gamma = gamma
+        self.lr = lr

     def predict(self, obs):
+        """ use value model self.model to predict the action value
+        """
         return self.model.value(obs)

-    def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
+    def learn(self,
+              obs,
+              action,
+              reward,
+              next_obs,
+              terminal,
+              learning_rate=None):
+        """ update value model self.model with DQN algorithm
+        """
+        # Support the modification of learning_rate
+        if learning_rate is None:
+            assert isinstance(
+                self.lr,
+                float), "Please set the learning rate of DQN in initialization."
+            learning_rate = self.lr
+
         pred_value = self.model.value(obs)
         action_onehot = layers.one_hot(action, self.act_dim)
         action_onehot = layers.cast(action_onehot, dtype='float32')
......
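The context above stops just before the step that makes DDQN "double": the greedy next action is chosen with the online model but evaluated with the target model. A rough sketch of that target computation, continuing the learn() body shown above (illustrative only, not a verbatim copy of the file):

# Double-DQN target: select the action with self.model, evaluate it with self.target_model.
next_action_value = self.model.value(next_obs)
greedy_action = layers.argmax(next_action_value, axis=1)
greedy_action = layers.unsqueeze(greedy_action, axes=[1])
greedy_action_onehot = layers.cast(
    layers.one_hot(greedy_action, self.act_dim), dtype='float32')

next_pred_value = self.target_model.value(next_obs)
max_v = layers.reduce_sum(
    layers.elementwise_mul(greedy_action_onehot, next_pred_value), dim=1)
max_v.stop_gradient = True  # the target must not receive gradients

target = reward + (
    1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * max_v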
@@ -24,7 +24,7 @@ __all__ = ['DQN']


 class DQN(Algorithm):
-    def __init__(self, model, act_dim=None, gamma=None):
+    def __init__(self, model, act_dim=None, gamma=None, lr=None):
         """ DQN algorithm

         Args:
@@ -38,17 +38,31 @@ class DQN(Algorithm):
         assert isinstance(act_dim, int)
         assert isinstance(gamma, float)

         self.act_dim = act_dim
         self.gamma = gamma
+        self.lr = lr

     def predict(self, obs):
         """ use value model self.model to predict the action value
         """
         return self.model.value(obs)

-    def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
+    def learn(self,
+              obs,
+              action,
+              reward,
+              next_obs,
+              terminal,
+              learning_rate=None):
         """ update value model self.model with DQN algorithm
         """
+        # Support the modification of learning_rate
+        if learning_rate is None:
+            assert isinstance(
+                self.lr,
+                float), "Please set the learning rate of DQN in initialization."
+            learning_rate = self.lr
+
         pred_value = self.model.value(obs)
         next_pred_value = self.target_model.value(next_obs)
......
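For orientation, the truncated remainder of learn() is where learning_rate is actually consumed: after the TD loss is assembled, the optimizer is created with that rate while the train program is being built. A rough sketch of how such a fluid DQN loss typically ends, continuing the body above (illustrative only, not a verbatim copy of the file):

# TD target: max over the target network's Q-values for the next state.
best_v = layers.reduce_max(next_pred_value, dim=1)
best_v.stop_gradient = True
target = reward + (
    1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * best_v

# Q(s, a) for the actions actually taken.
action_onehot = layers.cast(
    layers.one_hot(action, self.act_dim), dtype='float32')
pred_action_value = layers.reduce_sum(
    layers.elementwise_mul(action_onehot, pred_value), dim=1)
cost = layers.reduce_mean(layers.square_error_cost(pred_action_value, target))

# learning_rate (self.lr by default, or whatever was passed to learn()) is
# consumed only here, when the optimizer is attached at program-build time;
# it may be a Python float or a fed tensor, which is what keeps the old
# runtime-fed interface working.
optimizer = fluid.optimizer.Adam(learning_rate=learning_rate)
optimizer.minimize(cost)
return cost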
@@ -89,7 +89,7 @@ MAX_INT32 = 0x7fffffff
 try:
     from paddle import fluid
     fluid_version = get_fluid_version()
-    assert fluid_version >= 161, "PARL requires paddle>=1.6.1"
+    assert fluid_version >= 161 or fluid_version == 0, "PARL requires paddle>=1.6.1"
     _HAS_FLUID = True
 except ImportError:
     _HAS_FLUID = False
......
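On the last hunk: development builds of PaddlePaddle report their version as 0.0.0, so a numeric encoding of the version string comes out as 0 and would be rejected by the plain >= 161 check; allowing fluid_version == 0 presumably whitelists those builds. A hypothetical sketch of such an encoding (get_fluid_version's real implementation may differ):

import paddle

def get_fluid_version():
    # Hypothetical: encode "1.6.1" as 161; a develop build ("0.0.0") becomes 0.
    major, minor, patch = (paddle.__version__.split('.') + ['0', '0'])[:3]
    return int(major) * 100 + int(minor) * 10 + int(patch)

fluid_version = get_fluid_version()
assert fluid_version >= 161 or fluid_version == 0, "PARL requires paddle>=1.6.1"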