diff --git a/examples/DQN/cartpole_agent.py b/examples/DQN/cartpole_agent.py
index c7d16cf43a1729f7cfd62f0981679ec373dd77ae..d98f2ba7cdbd426754e3103ebd4068d5e9fb9871 100755
--- a/examples/DQN/cartpole_agent.py
+++ b/examples/DQN/cartpole_agent.py
@@ -54,10 +54,7 @@ class CartpoleAgent(parl.Agent):
             next_obs = layers.data(
                 name='next_obs', shape=[self.obs_dim], dtype='float32')
             terminal = layers.data(name='terminal', shape=[], dtype='bool')
-            lr = layers.data(
-                name='lr', shape=[1], dtype='float32', append_batch_size=False)
-            self.cost = self.alg.learn(obs, action, reward, next_obs, terminal,
-                                        lr)
+            self.cost = self.alg.learn(obs, action, reward, next_obs, terminal)
 
     def sample(self, obs):
         sample = np.random.rand()
@@ -78,7 +75,7 @@ class CartpoleAgent(parl.Agent):
         act = np.argmax(pred_Q)
         return act
 
-    def learn(self, obs, act, reward, next_obs, terminal, lr):
+    def learn(self, obs, act, reward, next_obs, terminal):
         if self.global_step % self.update_target_steps == 0:
             self.alg.sync_target()
         self.global_step += 1
@@ -90,7 +87,6 @@ class CartpoleAgent(parl.Agent):
             'reward': reward,
             'next_obs': next_obs.astype('float32'),
             'terminal': terminal,
-            'lr': np.float32([lr]),
         }
         cost = self.fluid_executor.run(
             self.learn_program, feed=feed, fetch_list=[self.cost])[0]
diff --git a/examples/DQN/train.py b/examples/DQN/train.py
index 871b50a8d6a734aad4d63011cb9d2fe4425e8384..b634b122eff4abc7177ff830387560c95fb2aa2b 100755
--- a/examples/DQN/train.py
+++ b/examples/DQN/train.py
@@ -45,8 +45,7 @@ def run_episode(agent, env, rpm):
             (batch_obs, batch_action, batch_reward, batch_next_obs,
              batch_isOver) = rpm.sample(BATCH_SIZE)
             train_loss = agent.learn(batch_obs, batch_action, batch_reward,
-                                     batch_next_obs, batch_isOver,
-                                     LEARNING_RATE)
+                                     batch_next_obs, batch_isOver)
 
         total_reward += reward
         obs = next_obs
@@ -80,7 +79,8 @@ def main():
     rpm = ReplayMemory(MEMORY_SIZE)
 
     model = CartpoleModel(act_dim=action_dim)
-    algorithm = parl.algorithms.DQN(model, act_dim=action_dim, gamma=GAMMA)
+    algorithm = parl.algorithms.DQN(
+        model, act_dim=action_dim, gamma=GAMMA, lr=LEARNING_RATE)
     agent = CartpoleAgent(
         algorithm,
         obs_dim=obs_shape[0],
diff --git a/parl/algorithms/fluid/ddqn.py b/parl/algorithms/fluid/ddqn.py
index f81ba6024e2f2096358424826a5bac772dc1b183..03c0ced5019abcef00151a68eca944b32caa8469 100644
--- a/parl/algorithms/fluid/ddqn.py
+++ b/parl/algorithms/fluid/ddqn.py
@@ -21,19 +21,17 @@ import paddle.fluid as fluid
 from parl.core.fluid.algorithm import Algorithm
 from parl.core.fluid import layers
 
+__all__ = ['DDQN']
+
 
 class DDQN(Algorithm):
-    def __init__(
-            self,
-            model,
-            act_dim=None,
-            gamma=None,
-    ):
+    def __init__(self, model, act_dim=None, gamma=None, lr=None):
         """ Double DQN algorithm
-
         Args:
-            model (parl.Model): model defining forward network of Q function.
+            model (parl.Model): model defining forward network of Q function
+            act_dim (int): dimension of the action space
             gamma (float): discounted factor for reward computation.
+            lr (float): learning rate.
         """
         self.model = model
         self.target_model = copy.deepcopy(model)
@@ -43,11 +41,29 @@ class DDQN(Algorithm):
 
         self.act_dim = act_dim
         self.gamma = gamma
+        self.lr = lr
 
     def predict(self, obs):
+        """ use value model self.model to predict the action value
+        """
         return self.model.value(obs)
 
-    def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
+    def learn(self,
+              obs,
+              action,
+              reward,
+              next_obs,
+              terminal,
+              learning_rate=None):
+        """ update value model self.model with DDQN algorithm
+        """
+        # Support the modification of learning_rate
+        if learning_rate is None:
+            assert isinstance(
+                self.lr,
+                float), "Please set the learning rate of DDQN in initialization."
+            learning_rate = self.lr
+
         pred_value = self.model.value(obs)
         action_onehot = layers.one_hot(action, self.act_dim)
         action_onehot = layers.cast(action_onehot, dtype='float32')
diff --git a/parl/algorithms/fluid/dqn.py b/parl/algorithms/fluid/dqn.py
index ed5f907a4457eaa25e86f662f685bc013f526b7b..56d05e0a67cf5d6653bba2e350a71bb08977733a 100644
--- a/parl/algorithms/fluid/dqn.py
+++ b/parl/algorithms/fluid/dqn.py
@@ -24,7 +24,7 @@ class DQN(Algorithm):
-    def __init__(self, model, act_dim=None, gamma=None):
+    def __init__(self, model, act_dim=None, gamma=None, lr=None):
         """ DQN algorithm
 
         Args:
@@ -38,17 +38,31 @@
         assert isinstance(act_dim, int)
         assert isinstance(gamma, float)
+
         self.act_dim = act_dim
         self.gamma = gamma
+        self.lr = lr
 
     def predict(self, obs):
         """ use value model self.model to predict the action value
         """
         return self.model.value(obs)
 
-    def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
+    def learn(self,
+              obs,
+              action,
+              reward,
+              next_obs,
+              terminal,
+              learning_rate=None):
         """ update value model self.model with DQN algorithm
         """
+        # Support the modification of learning_rate
+        if learning_rate is None:
+            assert isinstance(
+                self.lr,
+                float), "Please set the learning rate of DQN in initialization."
+            learning_rate = self.lr
 
         pred_value = self.model.value(obs)
         next_pred_value = self.target_model.value(next_obs)
diff --git a/parl/utils/utils.py b/parl/utils/utils.py
index 02c8ef2c2f0a1375ee00f0d0a8cc026039009767..a29a8c825017f9241ea59f76f5fe5e58de4f7b80 100644
--- a/parl/utils/utils.py
+++ b/parl/utils/utils.py
@@ -89,7 +89,7 @@ MAX_INT32 = 0x7fffffff
 try:
     from paddle import fluid
     fluid_version = get_fluid_version()
-    assert fluid_version >= 161, "PARL requires paddle>=1.6.1"
+    assert fluid_version >= 161 or fluid_version == 0, "PARL requires paddle>=1.6.1"
     _HAS_FLUID = True
 except ImportError:
     _HAS_FLUID = False