未验证 提交 942e570c 编写于 作者: Z zenghsh3 提交者: GitHub

Fix DDQN bug: cast batch_size to int64, consistent with the int64 dtype passed to layers.range (#353)

上级 779b5d4e
@@ -75,7 +75,7 @@ class DDQN(Algorithm):
 greedy_action = layers.argmax(next_action_value, axis=-1)
 # calculate the target q value with target network
-batch_size = layers.cast(layers.shape(greedy_action)[0], dtype='int32')
+batch_size = layers.cast(layers.shape(greedy_action)[0], dtype='int64')
 range_tmp = layers.range(
 start=0, end=batch_size, step=1, dtype='int64') * self.act_dim
 a_indices = range_tmp + greedy_action
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册