提交 07822fef 编写于 作者: M minqiyang

Clear all parameters' gradients

test=develop
上级 49a7fba8
...@@ -152,11 +152,13 @@ class VarBase { ...@@ -152,11 +152,13 @@ class VarBase {
void ClearGradient() { void ClearGradient() {
VLOG(1) << "clear gradient of " << var_desc_->Name(); VLOG(1) << "clear gradient of " << var_desc_->Name();
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>(); if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
operators::math::set_constant( auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
*(platform::DeviceContextPool::Instance().Get( operators::math::set_constant(
grads_->var_->Get<framework::LoDTensor>().place())), *(platform::DeviceContextPool::Instance().Get(
grads_t, 0.0); grads_->var_->Get<framework::LoDTensor>().place())),
grads_t, 0.0);
}
} }
framework::LoDTensor& GradValue(); framework::LoDTensor& GradValue();
......
...@@ -52,8 +52,7 @@ class Layer(core.Layer): ...@@ -52,8 +52,7 @@ class Layer(core.Layer):
def clear_gradients(self):
    """Zero the gradients of all parameters of this layer.

    NOTE(review): reconstructed from a fused two-column (old|new) diff
    rendering; this is the post-change side. The change removes the old
    ``if not p._stop_gradient`` guard, so ``_clear_gradient()`` is now
    called on every parameter unconditionally.
    """
    for p in self.parameters():
        p._clear_gradient()
def _build_once(self, inputs): def _build_once(self, inputs):
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册