Commit 6c03542e authored by seatea

Fix dtype bug for loss_scale and weight_decay.

1. Change the dtype of scale to the dtype of grad in loss_scale.py;
2. Change the dtype of weight_decay to the dtype of weight in optimizer.py.
Parent 930a1fb0
optimizer.py
@@ -84,7 +84,7 @@ apply_decay = C.MultitypeFuncGraph("apply_decay")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
-        return op_add((gradient, weight * F.scalar_to_array(weight_decay)))
+        return op_add((gradient, weight * weight_decay))
    return gradient
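Below is a minimal sketch, not part of the patch, of the decay term computed above. It assumes the public MindSpore API (Tensor and the float16 dtype); the point is that weight_decay is materialized in the weight's own dtype, so the multiply-add stays in a single precision.

import numpy as np
import mindspore as ms
from mindspore import Tensor

# Hypothetical float16 parameter and gradient, as under mixed-precision training.
weight = Tensor(np.ones((2, 2)), ms.float16)
gradient = Tensor(np.ones((2, 2)), ms.float16)
# Decay factor built in the weight's dtype, matching the intent of the fix.
weight_decay = Tensor(0.01, ms.float16)

# Same arithmetic as _tensor_apply_decay: gradient + weight * weight_decay.
decayed_grad = gradient + weight * weight_decay
print(decayed_grad.dtype)  # Float16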
loss_scale.py
@@ -32,7 +32,7 @@ reciprocal = P.Reciprocal()
@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
-    return grad * reciprocal(scale)
+    return grad * F.cast(reciprocal(scale), F.dtype(grad))


class DynamicLossScaleUpdateCell(Cell):
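A minimal sketch, not part of the patch, of the fixed unscaling step, assuming the public MindSpore API (Tensor, ops.Reciprocal, ops.cast). Under mixed precision the gradients are float16 while the loss scale is kept in float32, so the reciprocal of the scale is cast to the gradient's dtype before the multiply, mirroring the F.cast added above.

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

reciprocal = ops.Reciprocal()

grad = Tensor(np.ones((2, 2)), ms.float16)   # gradient produced in float16
scale = Tensor(1024.0, ms.float32)           # loss scale kept in float32

# Before the fix: grad * reciprocal(scale) mixed float16 with float32.
# After the fix: cast the scale's reciprocal to the gradient's dtype first.
unscaled = grad * ops.cast(reciprocal(scale), grad.dtype)
print(unscaled.dtype)  # Float16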