未验证 提交 f46ad361 编写于 作者: H Hongsheng Zeng 提交者: GitHub

Fix A2C so it can run in Paddle 1.6.0 (#232)

* fix A2C so it can run in Paddle 1.6.0

* fix impala compatibility

* yapf
上级 8c9bf1fa
......@@ -71,7 +71,10 @@ class AtariAgent(parl.Agent):
lr = layers.data(
name='lr', shape=[1], dtype='float32', append_batch_size=False)
entropy_coeff = layers.data(
name='entropy_coeff', shape=[], dtype='float32')
name='entropy_coeff',
shape=[1],
dtype='float32',
append_batch_size=False)
total_loss, pi_loss, vf_loss, entropy = self.alg.learn(
obs, actions, advantages, target_values, lr, entropy_coeff)
......
......@@ -58,7 +58,10 @@ class AtariAgent(parl.Agent):
lr = layers.data(
name='lr', shape=[1], dtype='float32', append_batch_size=False)
entropy_coeff = layers.data(
name='entropy_coeff', shape=[], dtype='float32')
name='entropy_coeff',
shape=[1],
dtype='float32',
append_batch_size=False)
self.learn_reader = fluid.layers.create_py_reader_by_data(
capacity=32,
......
......@@ -123,7 +123,7 @@ class Learner(object):
obs_np, actions_np, behaviour_logits_np, rewards_np,
dones_np,
np.float32(self.lr),
np.float32(self.entropy_coeff)
np.array([self.entropy_coeff], dtype='float32')
]
def run_learn(self):
......
......@@ -67,7 +67,10 @@ class LiftAgent(parl.Agent):
lr = layers.data(
name='lr', shape=[1], dtype='float32', append_batch_size=False)
entropy_coeff = layers.data(
name='entropy_coeff', shape=[], dtype='float32')
name='entropy_coeff',
shape=[1],
dtype='float32',
append_batch_size=False)
total_loss, pi_loss, vf_loss, entropy = self.alg.learn(
obs, actions, advantages, target_values, lr, entropy_coeff)
......
......@@ -72,7 +72,6 @@ class A3C(Algorithm):
policy_entropy = policy_distribution.entropy()
entropy = layers.reduce_sum(policy_entropy)
entropy_coeff = layers.reshape(entropy_coeff, shape=[1])
total_loss = (
pi_loss + vf_loss * self.vf_loss_coeff + entropy * entropy_coeff)
......
......@@ -78,7 +78,6 @@ class VTraceLoss(object):
self.entropy = layers.reduce_sum(policy_entropy)
# The summed weighted loss
entropy_coeff = layers.reshape(entropy_coeff, shape=[1])
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff +
self.entropy * entropy_coeff)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册