Commit 07822fef authored by minqiyang

Clear all parameters' gradients

test=develop
Parent 49a7fba8
@@ -152,12 +152,14 @@ class VarBase {
  void ClearGradient() {
    VLOG(1) << "clear gradient of " << var_desc_->Name();
    if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
      auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
      operators::math::set_constant(
          *(platform::DeviceContextPool::Instance().Get(
              grads_->var_->Get<framework::LoDTensor>().place())),
          grads_t, 0.0);
    }
  }

  framework::LoDTensor& GradValue();
@@ -52,7 +52,6 @@ class Layer(core.Layer):
    def clear_gradients(self):
        for p in self.parameters():
            if not p._stop_gradient:
                p._clear_gradient()

    def _build_once(self, inputs):
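A minimal usage sketch of the new behavior: after a backward pass, calling clear_gradients() on a layer zero-fills the gradient tensor of every trainable parameter in place, so the next iteration accumulates from zero. Only Layer.clear_gradients() and the per-parameter _clear_gradient() come from this commit; the layer class, data reader, optimizer, and backward call below are assumptions for illustration, following the imperative-mode conventions of this era of the codebase.

# Hypothetical training-step sketch (not part of this commit).
import paddle.fluid as fluid

with fluid.imperative.guard():                 # assumed imperative (dygraph) mode guard
    model = MyLayer()                          # hypothetical fluid Layer subclass
    sgd = fluid.optimizer.SGDOptimizer(learning_rate=0.01)

    for x, y in train_reader():                # hypothetical data source
        loss = model(x, y)                     # forward pass
        loss._backward()                       # assumed backward entry point of this era
        sgd.minimize(loss)                     # apply the accumulated gradients
        model.clear_gradients()                # zero every parameter's gradient in place

Zeroing in place (via set_constant on the gradient's own device context, as in the C++ hunk above) keeps the already-allocated gradient buffer alive rather than destroying and re-creating it each iteration.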