提交 07822fef 编写于 作者: M minqiyang

Clear all parameters' gradient

test=develop
上级 49a7fba8
...@@ -152,12 +152,14 @@ class VarBase { ...@@ -152,12 +152,14 @@ class VarBase {
void ClearGradient() { void ClearGradient() {
VLOG(1) << "clear gradient of " << var_desc_->Name(); VLOG(1) << "clear gradient of " << var_desc_->Name();
if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>(); auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant( operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get( *(platform::DeviceContextPool::Instance().Get(
grads_->var_->Get<framework::LoDTensor>().place())), grads_->var_->Get<framework::LoDTensor>().place())),
grads_t, 0.0); grads_t, 0.0);
} }
}
framework::LoDTensor& GradValue(); framework::LoDTensor& GradValue();
......
...@@ -52,7 +52,6 @@ class Layer(core.Layer): ...@@ -52,7 +52,6 @@ class Layer(core.Layer):
def clear_gradients(self): def clear_gradients(self):
for p in self.parameters(): for p in self.parameters():
if not p._stop_gradient:
p._clear_gradient() p._clear_gradient()
def _build_once(self, inputs): def _build_once(self, inputs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册