Unverified commit 74824fdd authored by Leo Chen, committed by GitHub

add clearGradient for amp sample code (#32517)

Parent 4e460d7b
@@ -187,6 +187,7 @@ size_t VarBase::GradOpNum() const {
 }
 
 void VarBase::ClearGradient() {
+  VLOG(4) << "ClearGradient " << Name();
   if (grad_var_) {
     if (grad_var_->Var().IsType<framework::SelectedRows>()) {
       auto* grad_t =
...
@@ -62,6 +62,7 @@ class GradScaler(AmpScaler):
             scaled = scaler.scale(loss)  # scale the loss
             scaled.backward()  # do backward
             scaler.minimize(optimizer, scaled)  # update parameters
+            optimizer.clear_grad()
         """
def __init__(self, def __init__(self,
@@ -105,6 +106,7 @@ class GradScaler(AmpScaler):
             scaled = scaler.scale(loss)  # scale the loss
             scaled.backward()  # do backward
             scaler.minimize(optimizer, scaled)  # update parameters
+            optimizer.clear_grad()
         """
         return super(GradScaler, self).scale(var)
@@ -140,5 +142,6 @@ class GradScaler(AmpScaler):
             scaled = scaler.scale(loss)  # scale the loss
             scaled.backward()  # do backward
             scaler.minimize(optimizer, scaled)  # update parameters
+            optimizer.clear_grad()
         """
         return super(GradScaler, self).minimize(optimizer, *args, **kwargs)
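
For context, below is a minimal sketch of the full AMP training step that these docstring examples abbreviate, including the `optimizer.clear_grad()` call this commit adds. The model, optimizer, and input data (a single linear layer, SGD, a random batch) are illustrative assumptions, not part of the commit; only the scale/backward/minimize/clear_grad sequence comes from the diff above.

```python
# A minimal, self-contained sketch of the AMP pattern shown in the
# docstrings above. The model, optimizer, and data are assumptions
# for illustration; only the scale/backward/minimize/clear_grad
# sequence is from the commit.
import paddle

model = paddle.nn.Linear(10, 10)  # hypothetical model
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                 parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

data = paddle.rand([4, 10])  # hypothetical input batch
with paddle.amp.auto_cast():  # run the forward pass in fp16 where safe
    loss = model(data).mean()

scaled = scaler.scale(loss)         # scale the loss
scaled.backward()                   # do backward
scaler.minimize(optimizer, scaled)  # update parameters
optimizer.clear_grad()              # reset gradients for the next step
```

The added call matters because Paddle's dynamic graph mode accumulates gradients across iterations by default; without `optimizer.clear_grad()`, copying the docstring example into a training loop would silently sum gradients from previous steps.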