From f790b96d6f8fc92f9b23cf852ceb7d7f72f6e677 Mon Sep 17 00:00:00 2001 From: "Yang Yang(Tony)" Date: Wed, 13 Jun 2018 17:49:36 -0700 Subject: [PATCH] make variable->Grad() a weak_ptr (#11453) * fix #11416 * make sgd check tape has been backwarded_ * add error message --- paddle/contrib/tape/function.h | 3 ++- paddle/contrib/tape/tape.h | 2 ++ paddle/contrib/tape/variable.h | 14 +++++++------- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/paddle/contrib/tape/function.h b/paddle/contrib/tape/function.h index 0584f4ec8aa..8c9694d9a21 100644 --- a/paddle/contrib/tape/function.h +++ b/paddle/contrib/tape/function.h @@ -112,6 +112,8 @@ class SGD { } void operator()(VariableHandle input) { + PADDLE_ENFORCE(get_global_tape().HasBeenBackwarded(), + "optimization must happen after the backward"); Tape temp_tape; temp_tape.AddOp("sgd", {{"Param", {input}}, @@ -120,7 +122,6 @@ class SGD { {{"ParamOut", {input}}}, {}); temp_tape.Forward(); - input->ResetGrad(); } private: diff --git a/paddle/contrib/tape/tape.h b/paddle/contrib/tape/tape.h index 9938ce9a7f4..ed79de17a7f 100644 --- a/paddle/contrib/tape/tape.h +++ b/paddle/contrib/tape/tape.h @@ -47,6 +47,8 @@ class Tape { void Forward(); void Backward(VariableHandle target); + bool HasBeenBackwarded() { return has_been_backwarded_; } + private: bool has_been_backwarded_ = false; size_t current_position_ = 0; diff --git a/paddle/contrib/tape/variable.h b/paddle/contrib/tape/variable.h index 7e63aa38a7a..35c328e69c9 100644 --- a/paddle/contrib/tape/variable.h +++ b/paddle/contrib/tape/variable.h @@ -45,15 +45,15 @@ class Variable { void InitializeVariable(); VariableHandle Grad() { - if (grad_ == nullptr) { - grad_.reset(new Variable(desc_.Name(), true)); + if (grad_.expired()) { + VariableHandle new_grad(new Variable(desc_.Name(), true)); + grad_ = new_grad; + return new_grad; + } else { + return VariableHandle(grad_); } - - return grad_; } - void ResetGrad() { grad_ = nullptr; } - // Stochastic Gradient Descent with Momentum // VariableHandle Momentum (); @@ -79,7 +79,7 @@ class Variable { framework::VarDesc desc_; framework::Variable var_; - VariableHandle grad_; + std::weak_ptr grad_; }; } } -- GitLab