Commit cd401b92 authored by niuyazhe

hotfix(nyz): fix ppo bug when use dual_clip and adv > 0

Parent dd4472e4
@@ -98,7 +98,10 @@ def ppo_policy_error(data: namedtuple,
     surr1 = ratio * adv
     surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
     if dual_clip is not None:
-        policy_loss = (-torch.max(torch.min(surr1, surr2), dual_clip * adv) * weight).mean()
+        clip1 = torch.min(surr1, surr2)
+        clip2 = torch.max(clip1, dual_clip * adv)
+        # only use dual_clip when adv < 0
+        policy_loss = -(torch.where(adv < 0, clip2, clip1) * weight).mean()
     else:
         policy_loss = (-torch.min(surr1, surr2) * weight).mean()
     with torch.no_grad():
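The rationale behind the patch: with a positive advantage, `torch.max(torch.min(surr1, surr2), dual_clip * adv)` lets the dual-clip term `dual_clip * adv` (which is larger than both surrogates whenever `dual_clip > 1`) win the max, replacing the clipped objective with a constant and breaking the loss. Dual clipping is only meant to lower-bound the objective when the advantage is negative. A minimal per-sample sketch of the corrected rule in plain Python (scalar stand-ins for the tensor ops; the function name and default values here are illustrative, not from the repo):

```python
def ppo_surrogate(ratio, adv, clip_ratio=0.2, dual_clip=3.0):
    """Per-sample PPO surrogate objective with the corrected dual-clip rule.

    Mirrors the patched torch code: the dual-clip lower bound
    dual_clip * adv is applied only when adv < 0; for adv > 0 the
    ordinary clipped objective is used unchanged.
    """
    surr1 = ratio * adv
    # clamp ratio to [1 - clip_ratio, 1 + clip_ratio] before scaling by adv
    surr2 = max(1 - clip_ratio, min(ratio, 1 + clip_ratio)) * adv
    clip1 = min(surr1, surr2)  # standard PPO clipped objective
    if adv < 0:
        # only use dual_clip when adv < 0: bound how negative the
        # objective can get, limiting the penalty on a single sample
        return max(clip1, dual_clip * adv)
    return clip1
```

For example, with `ratio = 0.1`, `adv = 1.0` the old formula would have returned `max(0.1, 3.0) = 3.0`, a constant independent of `ratio`; the fixed rule returns the ordinary clipped value `0.1`. With `ratio = 5.0`, `adv = -1.0` the dual clip correctly bounds the objective at `-3.0` instead of `-5.0`.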