From c3236f82d698a017e4dffa6d2a0c912a594a43f8 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Sun, 2 Dec 2018 17:11:58 +0800 Subject: [PATCH] polish --- paddle/fluid/imperative/layer.cc | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 08379a7ed5..c5a8d9c6b6 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -44,16 +44,12 @@ class Autograd { public: explicit Autograd(framework::Scope* scope) : scope_(scope) {} - void RunBackward(VarBase* var, framework::Variable* grad) { - if (!var->pre_op_) { - var->ApplyGrad(scope_, grad); - return; - } + void RunBackward(VarBase* var) { PADDLE_ENFORCE(var->pre_op_->op_desc_); // TODO(panyx0718): Only create vars that "require_grad" std::vector op_grads = CreateOpGrads(var->pre_op_->output_vars_->size()); - op_grads[var->pre_op_out_idx_] = grad; + op_grads[var->pre_op_out_idx_] = var->grads_; std::deque>> ready; ready.push_back(std::make_pair(var->pre_op_, op_grads)); @@ -238,8 +234,6 @@ std::vector OpBase::ApplyGrad(framework::Scope* scope) { framework::Variable* var = scope->FindVar(outvar); LOG(ERROR) << "apply grad " << outvar << " with origin " << origin_var; - // TODO(panyx0718): Accumulate. - // origin_in_var->grads_ = var; origin_in_var->ApplyGrad(scope, var); ret[i] = var; // TODO(panyx0718): There might be 2 var with the same name. We @@ -254,15 +248,11 @@ std::vector OpBase::ApplyGrad(framework::Scope* scope) { } void VarBase::RunBackward(framework::Scope* scope) { - // TODO(panyx0718): Might not be 0th, need to detect. - grads_ = CreateVariable(pre_op_->grad_op_desc_->InputArgumentNames()[0], + grads_ = CreateVariable(framework::GradVarName(var_desc_->Name()), var_->Get().dims(), 1.0, scope, false); - framework::Variable* grad = - CreateVariable("init@imperative_grad", - var_->Get().dims(), 1.0, scope); - - Autograd(scope).RunBackward(this, grad); + if (!pre_op_) return; + Autograd(scope).RunBackward(this); } } // namespace imperative -- GitLab