未验证 提交 74824fdd 编写于 作者: L Leo Chen 提交者: GitHub

add clearGradient for amp sample code (#32517)

上级 4e460d7b
......@@ -187,6 +187,7 @@ size_t VarBase::GradOpNum() const {
}
void VarBase::ClearGradient() {
VLOG(4) << "ClearGradient " << Name();
if (grad_var_) {
if (grad_var_->Var().IsType<framework::SelectedRows>()) {
auto* grad_t =
......
......@@ -62,6 +62,7 @@ class GradScaler(AmpScaler):
scaled = scaler.scale(loss) # scale the loss
scaled.backward() # do backward
scaler.minimize(optimizer, scaled) # update parameters
optimizer.clear_grad()
"""
def __init__(self,
......@@ -105,6 +106,7 @@ class GradScaler(AmpScaler):
scaled = scaler.scale(loss) # scale the loss
scaled.backward() # do backward
scaler.minimize(optimizer, scaled) # update parameters
optimizer.clear_grad()
"""
return super(GradScaler, self).scale(var)
......@@ -140,5 +142,6 @@ class GradScaler(AmpScaler):
scaled = scaler.scale(loss) # scale the loss
scaled.backward() # do backward
scaler.minimize(optimizer, scaled) # update parameters
optimizer.clear_grad()
"""
return super(GradScaler, self).minimize(optimizer, *args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册