Unverified commit a9159021, authored by rical730 and committed by GitHub

upgrade DQN's lr interface compatibility (#291)

* upgrade DQN's lr interface compatibility

* yapf

* update example DQN
Parent 533d4b2c
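In short, this commit moves the learning rate from a per-call argument of the agent's learn() to a constructor argument of the DQN/DDQN algorithm. A minimal before/after sketch, reusing the names from the Cartpole example changed below (model, agent, and the sampled batch are assumed to be set up as in that example):

    # Before this change: the learning rate was threaded through every call.
    # algorithm = parl.algorithms.DQN(model, act_dim=action_dim, gamma=GAMMA)
    # train_loss = agent.learn(batch_obs, batch_action, batch_reward,
    #                          batch_next_obs, batch_isOver, LEARNING_RATE)

    # After this change: the learning rate is fixed once at construction time.
    algorithm = parl.algorithms.DQN(
        model, act_dim=action_dim, gamma=GAMMA, lr=LEARNING_RATE)
    train_loss = agent.learn(batch_obs, batch_action, batch_reward,
                             batch_next_obs, batch_isOver)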
@@ -54,10 +54,7 @@ class CartpoleAgent(parl.Agent):
next_obs = layers.data(
name='next_obs', shape=[self.obs_dim], dtype='float32')
terminal = layers.data(name='terminal', shape=[], dtype='bool')
lr = layers.data(
name='lr', shape=[1], dtype='float32', append_batch_size=False)
self.cost = self.alg.learn(obs, action, reward, next_obs, terminal,
lr)
self.cost = self.alg.learn(obs, action, reward, next_obs, terminal)
def sample(self, obs):
sample = np.random.rand()
@@ -78,7 +75,7 @@ class CartpoleAgent(parl.Agent):
act = np.argmax(pred_Q)
return act
def learn(self, obs, act, reward, next_obs, terminal, lr):
def learn(self, obs, act, reward, next_obs, terminal):
if self.global_step % self.update_target_steps == 0:
self.alg.sync_target()
self.global_step += 1
@@ -90,7 +87,6 @@ class CartpoleAgent(parl.Agent):
'reward': reward,
'next_obs': next_obs.astype('float32'),
'terminal': terminal,
'lr': np.float32([lr]),
}
cost = self.fluid_executor.run(
self.learn_program, feed=feed, fetch_list=[self.cost])[0]
......
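Putting the fragments of this hunk together, the example agent's learn() no longer builds or feeds an 'lr' entry. A condensed sketch of the updated method is below; the 'obs' and 'act' feed entries sit above the visible hunk and are assumed to keep their usual form, and the method is assumed to return the fetched cost:

    def learn(self, obs, act, reward, next_obs, terminal):
        # Periodically sync the target network, as in the hunk above.
        if self.global_step % self.update_target_steps == 0:
            self.alg.sync_target()
        self.global_step += 1

        feed = {
            'obs': obs.astype('float32'),    # assumed, above the visible hunk
            'act': act.astype('int32'),      # assumed, above the visible hunk
            'reward': reward,
            'next_obs': next_obs.astype('float32'),
            'terminal': terminal,
            # the 'lr' entry is gone; the rate now lives in the algorithm
        }
        cost = self.fluid_executor.run(
            self.learn_program, feed=feed, fetch_list=[self.cost])[0]
        return cost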
@@ -45,8 +45,7 @@ def run_episode(agent, env, rpm):
(batch_obs, batch_action, batch_reward, batch_next_obs,
batch_isOver) = rpm.sample(BATCH_SIZE)
train_loss = agent.learn(batch_obs, batch_action, batch_reward,
batch_next_obs, batch_isOver,
LEARNING_RATE)
batch_next_obs, batch_isOver)
total_reward += reward
obs = next_obs
@@ -80,7 +79,8 @@ def main():
rpm = ReplayMemory(MEMORY_SIZE)
model = CartpoleModel(act_dim=action_dim)
algorithm = parl.algorithms.DQN(model, act_dim=action_dim, gamma=GAMMA)
algorithm = parl.algorithms.DQN(
model, act_dim=action_dim, gamma=GAMMA, lr=LEARNING_RATE)
agent = CartpoleAgent(
algorithm,
obs_dim=obs_shape[0],
......
@@ -21,19 +21,17 @@ import paddle.fluid as fluid
from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid import layers
__all__ = ['DDQN']
class DDQN(Algorithm):
def __init__(
self,
model,
act_dim=None,
gamma=None,
):
def __init__(self, model, act_dim=None, gamma=None, lr=None):
""" Double DQN algorithm
Args:
model (parl.Model): model defining forward network of Q function.
model (parl.Model): model defining forward network of Q function
act_dim (int): dimension of the action space
gamma (float): discounted factor for reward computation.
lr (float): learning rate.
"""
self.model = model
self.target_model = copy.deepcopy(model)
@@ -43,11 +41,29 @@ class DDQN(Algorithm):
self.act_dim = act_dim
self.gamma = gamma
self.lr = lr
def predict(self, obs):
""" use value model self.model to predict the action value
"""
return self.model.value(obs)
def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
def learn(self,
obs,
action,
reward,
next_obs,
terminal,
learning_rate=None):
""" update value model self.model with DQN algorithm
"""
# Support the modification of learning_rate
if learning_rate is None:
assert isinstance(
self.lr,
float), "Please set the learning rate of DQN in initialization."
learning_rate = self.lr
pred_value = self.model.value(obs)
action_onehot = layers.one_hot(action, self.act_dim)
action_onehot = layers.cast(action_onehot, dtype='float32')
......
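Because learn() still accepts learning_rate, an agent that needs a run-time-tunable rate can keep the old pattern and pass an lr tensor; a per-call value takes precedence over the lr stored at construction. A sketch of that pre-change pattern, kept here only to show the compatibility path (it is the same code that the example agent used to contain):

    # Inside the agent's build_program(), under the learn program guard:
    lr = layers.data(
        name='lr', shape=[1], dtype='float32', append_batch_size=False)
    self.cost = self.alg.learn(obs, action, reward, next_obs, terminal, lr)
    # ...and then feed 'lr': np.float32([lr_value]) on every learn() run.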
@@ -24,7 +24,7 @@ __all__ = ['DQN']
class DQN(Algorithm):
def __init__(self, model, act_dim=None, gamma=None):
def __init__(self, model, act_dim=None, gamma=None, lr=None):
""" DQN algorithm
Args:
@@ -38,17 +38,31 @@ class DQN(Algorithm):
assert isinstance(act_dim, int)
assert isinstance(gamma, float)
self.act_dim = act_dim
self.gamma = gamma
self.lr = lr
def predict(self, obs):
""" use value model self.model to predict the action value
"""
return self.model.value(obs)
def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
def learn(self,
obs,
action,
reward,
next_obs,
terminal,
learning_rate=None):
""" update value model self.model with DQN algorithm
"""
# Support the modification of learning_rate
if learning_rate is None:
assert isinstance(
self.lr,
float), "Please set the learning rate of DQN in initialization."
learning_rate = self.lr
pred_value = self.model.value(obs)
next_pred_value = self.target_model.value(next_obs)
......
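The added fallback is plain Python and easy to check in isolation. A framework-free sketch of just that logic (resolve_lr is a hypothetical helper written here only for illustration, mirroring the block added to learn()):

    def resolve_lr(call_lr, init_lr):
        # An explicitly passed learning_rate wins; otherwise fall back to the
        # lr given in __init__, which must then be a float.
        if call_lr is None:
            assert isinstance(init_lr, float), \
                "Please set the learning rate of DQN in initialization."
            return init_lr
        return call_lr

    print(resolve_lr(None, 1e-3))   # 0.001  -> uses the constructor lr
    print(resolve_lr(5e-4, 1e-3))   # 0.0005 -> per-call value still wins
    # resolve_lr(None, None) would raise the AssertionError above.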
@@ -89,7 +89,7 @@ MAX_INT32 = 0x7fffffff
try:
from paddle import fluid
fluid_version = get_fluid_version()
assert fluid_version >= 161, "PARL requires paddle>=1.6.1"
assert fluid_version >= 161 or fluid_version == 0, "PARL requires paddle>=1.6.1"
_HAS_FLUID = True
except ImportError:
_HAS_FLUID = False
......
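The relaxed assertion also accepts fluid_version == 0. A hedged guess at the reason, assuming get_fluid_version packs the dotted version string into an integer (so "1.6.1" becomes 161): develop or source builds of paddle report a version like "0.0.0", which packs to 0 and should not be rejected by the paddle>=1.6.1 requirement. A small illustration with a hypothetical pack_version helper:

    def pack_version(version_str):
        # Hypothetical stand-in for get_fluid_version: "1.6.1" -> 161.
        major, minor, patch = (int(x) for x in version_str.split('.')[:3])
        return major * 100 + minor * 10 + patch

    assert pack_version('1.6.1') == 161   # release build: must be >= 161
    assert pack_version('0.0.0') == 0     # develop build: passes via "== 0"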