未验证 提交 f790b96d 编写于 作者: Y Yang Yang(Tony) 提交者: GitHub

make variable->Grad() a weak_ptr (#11453)

* fix #11416

* make sgd check tape has been backwarded_

* add error message
上级 a59c3b73
...@@ -112,6 +112,8 @@ class SGD { ...@@ -112,6 +112,8 @@ class SGD {
} }
void operator()(VariableHandle input) { void operator()(VariableHandle input) {
PADDLE_ENFORCE(get_global_tape().HasBeenBackwarded(),
"optimization must happen after the backward");
Tape temp_tape; Tape temp_tape;
temp_tape.AddOp("sgd", temp_tape.AddOp("sgd",
{{"Param", {input}}, {{"Param", {input}},
...@@ -120,7 +122,6 @@ class SGD { ...@@ -120,7 +122,6 @@ class SGD {
{{"ParamOut", {input}}}, {{"ParamOut", {input}}},
{}); {});
temp_tape.Forward(); temp_tape.Forward();
input->ResetGrad();
} }
private: private:
......
...@@ -47,6 +47,8 @@ class Tape { ...@@ -47,6 +47,8 @@ class Tape {
void Forward(); void Forward();
void Backward(VariableHandle target); void Backward(VariableHandle target);
bool HasBeenBackwarded() { return has_been_backwarded_; }
private: private:
bool has_been_backwarded_ = false; bool has_been_backwarded_ = false;
size_t current_position_ = 0; size_t current_position_ = 0;
......
...@@ -45,15 +45,15 @@ class Variable { ...@@ -45,15 +45,15 @@ class Variable {
void InitializeVariable(); void InitializeVariable();
VariableHandle Grad() { VariableHandle Grad() {
if (grad_ == nullptr) { if (grad_.expired()) {
grad_.reset(new Variable(desc_.Name(), true)); VariableHandle new_grad(new Variable(desc_.Name(), true));
grad_ = new_grad;
return new_grad;
} else {
return VariableHandle(grad_);
} }
return grad_;
} }
void ResetGrad() { grad_ = nullptr; }
// Stochastic Gradient Descent with Momentum // Stochastic Gradient Descent with Momentum
// VariableHandle Momentum (); // VariableHandle Momentum ();
...@@ -79,7 +79,7 @@ class Variable { ...@@ -79,7 +79,7 @@ class Variable {
framework::VarDesc desc_; framework::VarDesc desc_;
framework::Variable var_; framework::Variable var_;
VariableHandle grad_; std::weak_ptr<Variable> grad_;
}; };
} }
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册