From 2e56337efd8b23ed3bc2eeddea5260a591c12d0a Mon Sep 17 00:00:00 2001
From: Bo Zhou <2466956298@qq.com>
Date: Wed, 24 Jun 2020 05:36:43 -0500
Subject: [PATCH] support paddle 1.8.2 (#317)

---
 examples/DQN_variant/train.py | 2 +-
 parl/algorithms/fluid/ddqn.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/DQN_variant/train.py b/examples/DQN_variant/train.py
index 5ca16df..e2ef720 100644
--- a/examples/DQN_variant/train.py
+++ b/examples/DQN_variant/train.py
@@ -93,7 +93,7 @@ def main():
     act_dim = env.action_space.n
 
     model = AtariModel(act_dim, args.algo)
-    if args.algo == 'Double':
+    if args.algo == 'DDQN':
         algorithm = parl.algorithms.DDQN(model, act_dim=act_dim, gamma=GAMMA)
     elif args.algo in ['DQN', 'Dueling']:
         algorithm = parl.algorithms.DQN(model, act_dim=act_dim, gamma=GAMMA)
diff --git a/parl/algorithms/fluid/ddqn.py b/parl/algorithms/fluid/ddqn.py
index 03c0ced..e2fa7dc 100644
--- a/parl/algorithms/fluid/ddqn.py
+++ b/parl/algorithms/fluid/ddqn.py
@@ -75,7 +75,7 @@ class DDQN(Algorithm):
         greedy_action = layers.argmax(next_action_value, axis=-1)
 
         # calculate the target q value with target network
-        batch_size = layers.cast(layers.shape(greedy_action)[0], dtype='int')
+        batch_size = layers.cast(layers.shape(greedy_action)[0], dtype='int32')
         range_tmp = layers.range(
             start=0, end=batch_size, step=1, dtype='int64') * self.act_dim
         a_indices = range_tmp + greedy_action
-- 
GitLab