From 07822fef2c692dd884abb7aa54b416a70409bb9c Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 28 Jan 2019 18:43:51 +0800 Subject: [PATCH] Clear all parameters' gradient test=develop --- paddle/fluid/imperative/layer.h | 12 +++++++----- python/paddle/fluid/imperative/layers.py | 3 +-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 46107341a4e..78205486c55 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -152,11 +152,13 @@ class VarBase { void ClearGradient() { VLOG(1) << "clear gradient of " << var_desc_->Name(); - auto grads_t = grads_->var_->GetMutable(); - operators::math::set_constant( - *(platform::DeviceContextPool::Instance().Get( - grads_->var_->Get().place())), - grads_t, 0.0); + if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) { + auto grads_t = grads_->var_->GetMutable(); + operators::math::set_constant( + *(platform::DeviceContextPool::Instance().Get( + grads_->var_->Get().place())), + grads_t, 0.0); + } } framework::LoDTensor& GradValue(); diff --git a/python/paddle/fluid/imperative/layers.py b/python/paddle/fluid/imperative/layers.py index c338c65a76b..71ff95bdea3 100644 --- a/python/paddle/fluid/imperative/layers.py +++ b/python/paddle/fluid/imperative/layers.py @@ -52,8 +52,7 @@ class Layer(core.Layer): def clear_gradients(self): for p in self.parameters(): - if not p._stop_gradient: - p._clear_gradient() + p._clear_gradient() def _build_once(self, inputs): pass -- GitLab