diff --git a/examples/DQN_variant/train.py b/examples/DQN_variant/train.py
index 5ca16df135c346a08f87efc6a694e1e289b8192c..e2ef7209d4e48bde0435605e275c799cef62b15d 100644
--- a/examples/DQN_variant/train.py
+++ b/examples/DQN_variant/train.py
@@ -93,7 +93,7 @@ def main():
     act_dim = env.action_space.n
 
     model = AtariModel(act_dim, args.algo)
-    if args.algo == 'Double':
+    if args.algo == 'DDQN':
         algorithm = parl.algorithms.DDQN(model, act_dim=act_dim, gamma=GAMMA)
     elif args.algo in ['DQN', 'Dueling']:
         algorithm = parl.algorithms.DQN(model, act_dim=act_dim, gamma=GAMMA)
diff --git a/parl/algorithms/fluid/ddqn.py b/parl/algorithms/fluid/ddqn.py
index 03c0ced5019abcef00151a68eca944b32caa8469..e2fa7dc951b27b19129804c16c38d53b97ac0b1f 100644
--- a/parl/algorithms/fluid/ddqn.py
+++ b/parl/algorithms/fluid/ddqn.py
@@ -75,7 +75,7 @@ class DDQN(Algorithm):
         greedy_action = layers.argmax(next_action_value, axis=-1)
 
         # calculate the target q value with target network
-        batch_size = layers.cast(layers.shape(greedy_action)[0], dtype='int')
+        batch_size = layers.cast(layers.shape(greedy_action)[0], dtype='int32')
         range_tmp = layers.range(
             start=0, end=batch_size, step=1, dtype='int64') * self.act_dim
         a_indices = range_tmp + greedy_action
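Context on the second hunk: `dtype='int'` is not a valid dtype string for fluid's `cast`, so it is corrected to `'int32'`. The surrounding lines implement DDQN's decoupled selection/evaluation: the online network picks the greedy next action, and the target network's Q-table is gathered at those actions via flattened indices. Below is a minimal numpy sketch of that flat-index gather, not PARL code; `target_value` and `max_target_q` are illustrative names that do not appear in the patch:

```python
# Sketch of the DDQN flat-index gather: pick Q_target(s', argmax_a Q_online(s', a))
# for each transition in a batch. Illustrative only; assumes random Q-tables.
import numpy as np

batch_size, act_dim = 4, 3
next_action_value = np.random.rand(batch_size, act_dim)  # online net Q(s', .)
target_value = np.random.rand(batch_size, act_dim)       # target net Q(s', .)

greedy_action = next_action_value.argmax(axis=-1)        # action chosen by online net
# Flattening [batch, act_dim] maps row i, column a to flat index i * act_dim + a,
# which is exactly what range_tmp + greedy_action computes in the patch.
a_indices = np.arange(batch_size, dtype=np.int64) * act_dim + greedy_action
max_target_q = target_value.reshape(-1)[a_indices]       # shape: [batch_size]
```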