From 93c16d96289e2805c01fb0bc36f4eecb854fb920 Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Sun, 2 Dec 2018 22:23:12 +0800
Subject: [PATCH] polish the autograd (need to verify correctness)

test=develop
---
 paddle/fluid/imperative/layer.cc | 98 ++++++++++++--------------------
 1 file changed, 37 insertions(+), 61 deletions(-)

diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index c5a8d9c6b67..2176ac78bb6 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -46,20 +46,16 @@ class Autograd {
   void RunBackward(VarBase* var) {
     PADDLE_ENFORCE(var->pre_op_->op_desc_);
-    // TODO(panyx0718): Only create vars that "require_grad"
-    std::vector<Variable*> op_grads =
-        CreateOpGrads(var->pre_op_->output_vars_->size());
-    op_grads[var->pre_op_out_idx_] = var->grads_;
+    // TODO(panyx0718): Only create for vars that "require_grad"
+    (*var->pre_op_->output_vars_)[var->pre_op_out_idx_]->grads_ = var->grads_;
 
-    std::deque<std::pair<OpBase*, std::vector<Variable*>>> ready;
-    ready.push_back(std::make_pair(var->pre_op_, op_grads));
+    std::deque<OpBase*> ready;
+    ready.push_back(var->pre_op_);
 
     std::map<OpBase*, int> dep_counts = ComputeDepCounts(var->pre_op_);
-    std::map<OpBase*, std::vector<Variable*>> visited;
 
     while (!ready.empty()) {
-      OpBase* ready_op = ready.front().first;
-      std::vector<Variable*> ready_op_grads = ready.front().second;
+      OpBase* ready_op = ready.front();
       ready.pop_front();
       std::vector<Variable*> input_grads = ready_op->ApplyGrad(scope_);
@@ -67,29 +63,12 @@ class Autograd {
         if (!input_grads[i]) continue;
         OpBase* pre_op = ready_op->pre_ops_->at(i);
         if (!pre_op) continue;
-        int pre_op_out_idx = ready_op->pre_ops_out_idx_->at(i);
 
         dep_counts[pre_op] -= 1;
         PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
         bool pre_op_ready = dep_counts[pre_op] == 0;
-
         if (pre_op_ready) {
-          if (visited.find(pre_op) == visited.end()) {
-            PADDLE_ENFORCE(pre_op->output_vars_->size() == 1);
-            visited[pre_op] = {input_grads[i]};
-          } else {
-            std::vector<Variable*>& pre_op_grads = visited[pre_op];
-            AccumGrads(pre_op_out_idx, input_grads[i], &pre_op_grads);
-          }
-          ready.push_back(std::make_pair(pre_op, visited[pre_op]));
-        } else {
-          if (visited.find(pre_op) == visited.end()) {
-            // TODO(panyx0718): Only create vars that "require_grad"
-            visited[pre_op] = CreateOpGrads(var->pre_op_->output_vars_->size());
-          } else {
-          }
-          std::vector<Variable*>& pre_op_grads = visited[pre_op];
-          AccumGrads(pre_op_out_idx, input_grads[i], &pre_op_grads);
+          ready.push_back(pre_op);
         }
       }
     }
@@ -184,27 +163,22 @@ void VarBase::ApplyGrad(framework::Scope* scope, Variable* grad) {
 std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
   VLOG(3) << "op grad " << grad_op_desc_->Type();
 
-  for (const std::string& invar : grad_op_desc_->InputArgumentNames()) {
-    block_->FindRecursiveOrCreateVar(invar);
-    framework::Variable* var = scope->Var(invar);
-    LOG(ERROR) << "op grad in var " << invar;
-    if (!var->IsInitialized()) {
-      framework::VarDesc* var_desc = block_->FindVar(invar);
-      if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
-        LOG(ERROR) << "grad op invar init " << invar;
-        var->GetMutable<framework::LoDTensor>();
-      } else {
-        LOG(ERROR) << "tracer doesn't support yet";
+  for (const std::string& grad_invar : grad_op_desc_->InputArgumentNames()) {
+    if (grad_to_var_->find(grad_invar) == grad_to_var_->end()) {
+      continue;
+    }
+    LOG(ERROR) << "op grad in var " << grad_invar;
+    block_->FindRecursiveOrCreateVar(grad_invar);
+    framework::Variable* var = scope->Var(grad_invar);
+    const std::string& invar = grad_to_var_->at(grad_invar);
+    for (VarBase* varbase : *output_vars_) {
+      if (varbase->var_desc_->Name() == invar) {
+        var->GetMutable<framework::LoDTensor>()->ShareDataWith(
+            varbase->grads_->Get<framework::LoDTensor>());
       }
-    } else {
-      var->GetMutable<framework::LoDTensor>()->type();
     }
   }
 
-  std::vector<Variable*> ret;
-  for (size_t i = 0; i < input_vars_->size(); ++i) {
-    ret.push_back(nullptr);
-  }
   for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
     LOG(ERROR) << "grad outvar " << outvar;
     block_->FindRecursiveOrCreateVar(outvar);
@@ -225,23 +199,25 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
   opbase->Run(*scope, platform::CPUPlace());
 
-  for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
-    if (grad_to_var_->find(outvar) != grad_to_var_->end()) {
-      std::string origin_var = (*grad_to_var_)[outvar];
-      for (size_t i = 0; i < input_vars_->size(); ++i) {
-        VarBase* origin_in_var = (*input_vars_)[i];
-        if (origin_in_var->var_desc_->Name() == origin_var) {
-          framework::Variable* var = scope->FindVar(outvar);
-          LOG(ERROR) << "apply grad " << outvar << " with origin "
-                     << origin_var;
-          origin_in_var->ApplyGrad(scope, var);
-          ret[i] = var;
-          // TODO(panyx0718): There might be 2 var with the same name. We
-          // currently assume the are the same Variable*. So it doesn't matter
-          // which one is used.
-          break;
-        }
-      }
+  std::vector<Variable*> ret;
+  for (size_t i = 0; i < input_vars_->size(); ++i) {
+    bool found = false;
+    for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
+      Variable* var = scope->FindVar(outvar);
+      VarBase* origin_var = (*input_vars_)[i];
+      std::string orig_var = grad_to_var_->at(outvar);
+      PADDLE_ENFORCE(origin_var->var_desc_->Name() == orig_var);
+      LOG(ERROR) << "apply grad " << outvar << " with origin " << orig_var;
+      origin_var->ApplyGrad(scope, var);
+      found = true;
+      ret.push_back(var);
+      // TODO(panyx0718): There might be another outvar with the same name.
+      // In that case, it doesn't matter the first one or the second one is
+      // used.
+      break;
+    }
+    if (!found) {
+      ret.push_back(nullptr);
    }
   }
   return ret;
-- 
GitLab
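Note (not part of the patch): the new RunBackward above is a dependency-counted breadth-first walk over the op graph, using the dep_counts map and the ready queue of OpBase pointers. The following is a minimal standalone sketch of just that traversal, with a hypothetical Node type standing in for OpBase and the actual gradient kernels, scopes, and tensor sharing omitted; it is an illustration under those assumptions, not the patch's code.

// Simplified sketch of the dep-count based backward traversal.
#include <deque>
#include <map>
#include <vector>

struct Node {
  std::vector<Node*> pre_ops;  // ops that produced this op's inputs
};

// For every op reachable from `start`, count how many downstream ops must
// run their grad op before this op's output gradients are all available.
std::map<Node*, int> ComputeDepCounts(Node* start) {
  std::map<Node*, int> deps;
  std::map<Node*, bool> visited;
  std::deque<Node*> queue = {start};
  while (!queue.empty()) {
    Node* op = queue.front();
    queue.pop_front();
    for (Node* pre : op->pre_ops) {
      if (!pre) continue;
      ++deps[pre];
      if (!visited[pre]) {
        visited[pre] = true;
        queue.push_back(pre);
      }
    }
  }
  return deps;
}

// Visit ops in reverse topological order: an op is pushed onto `ready` only
// once every op that consumes its outputs has been processed.
void RunBackward(Node* loss_op) {
  std::map<Node*, int> deps = ComputeDepCounts(loss_op);
  std::deque<Node*> ready = {loss_op};
  while (!ready.empty()) {
    Node* op = ready.front();
    ready.pop_front();
    // The patched code calls ready_op->ApplyGrad(scope_) here to run the
    // grad kernel and collect per-input gradients.
    for (Node* pre : op->pre_ops) {
      if (!pre) continue;
      if (--deps[pre] == 0) {
        ready.push_back(pre);
      }
    }
  }
}

The patched version additionally skips predecessors whose input gradient came back null and shares the produced gradient tensors through the scope via ShareDataWith, which this sketch leaves out.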