Commit 93c16d96 authored by Xin Pan

polish the autograd (need to verify correctness)

test=develop
Parent c3236f82
@@ -46,20 +46,16 @@ class Autograd {
   void RunBackward(VarBase* var) {
     PADDLE_ENFORCE(var->pre_op_->op_desc_);
-    // TODO(panyx0718): Only create vars that "require_grad"
-    std::vector<Variable*> op_grads =
-        CreateOpGrads(var->pre_op_->output_vars_->size());
-    op_grads[var->pre_op_out_idx_] = var->grads_;
+    // TODO(panyx0718): Only create for vars that "require_grad"
+    (*var->pre_op_->output_vars_)[var->pre_op_out_idx_]->grads_ = var->grads_;
 
-    std::deque<std::pair<OpBase*, std::vector<Variable*>>> ready;
-    ready.push_back(std::make_pair(var->pre_op_, op_grads));
+    std::deque<OpBase*> ready;
+    ready.push_back(var->pre_op_);
 
     std::map<OpBase*, int> dep_counts = ComputeDepCounts(var->pre_op_);
-    std::map<OpBase*, std::vector<Variable*>> visited;
 
     while (!ready.empty()) {
-      OpBase* ready_op = ready.front().first;
-      std::vector<Variable*> ready_op_grads = ready.front().second;
+      OpBase* ready_op = ready.front();
       ready.pop_front();
       std::vector<Variable*> input_grads = ready_op->ApplyGrad(scope_);
@@ -67,29 +63,12 @@ class Autograd {
         if (!input_grads[i]) continue;
         OpBase* pre_op = ready_op->pre_ops_->at(i);
         if (!pre_op) continue;
-        int pre_op_out_idx = ready_op->pre_ops_out_idx_->at(i);
 
         dep_counts[pre_op] -= 1;
         PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
         bool pre_op_ready = dep_counts[pre_op] == 0;
-
         if (pre_op_ready) {
-          if (visited.find(pre_op) == visited.end()) {
-            PADDLE_ENFORCE(pre_op->output_vars_->size() == 1);
-            visited[pre_op] = {input_grads[i]};
-          } else {
-            std::vector<Variable*>& pre_op_grads = visited[pre_op];
-            AccumGrads(pre_op_out_idx, input_grads[i], &pre_op_grads);
-          }
-          ready.push_back(std::make_pair(pre_op, visited[pre_op]));
-        } else {
-          if (visited.find(pre_op) == visited.end()) {
-            // TODO(panyx0718): Only create vars that "require_grad"
-            visited[pre_op] = CreateOpGrads(var->pre_op_->output_vars_->size());
-          } else {
-          }
-          std::vector<Variable*>& pre_op_grads = visited[pre_op];
-          AccumGrads(pre_op_out_idx, input_grads[i], &pre_op_grads);
+          ready.push_back(pre_op);
         }
       }
     }
   }
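Reviewer note: the new RunBackward drops the visited/AccumGrads bookkeeping that used to carry per-op gradient vectors through the queue. The seed gradient is written straight into the output VarBase's grads_, and the traversal itself only needs a dependency counter plus a ready queue: an op's gradient runs once every op that consumed one of its outputs has finished. Below is a minimal, self-contained sketch of that dependency-counting walk; Node, ComputeDepCounts, and RunBackward here are simplified stand-ins, not Paddle's actual OpBase/Autograd types.

```cpp
#include <deque>
#include <iostream>
#include <map>
#include <vector>

// Hypothetical stand-in for OpBase: an op and the ops that produced its inputs.
struct Node {
  const char* name;
  std::vector<Node*> pre_ops;
};

// For every predecessor reachable from `start`, count how many consumer edges
// point at it (how many grad ops must run before its own grad op is ready).
std::map<Node*, int> ComputeDepCounts(Node* start) {
  std::map<Node*, int> deps;
  std::deque<Node*> queue{start};
  std::map<Node*, bool> visited{{start, true}};
  while (!queue.empty()) {
    Node* op = queue.front();
    queue.pop_front();
    for (Node* pre : op->pre_ops) {
      ++deps[pre];
      if (!visited[pre]) {
        visited[pre] = true;
        queue.push_back(pre);
      }
    }
  }
  return deps;
}

// Run "gradient ops" in reverse topological order, starting from the loss op.
void RunBackward(Node* loss) {
  std::map<Node*, int> deps = ComputeDepCounts(loss);
  std::deque<Node*> ready{loss};
  while (!ready.empty()) {
    Node* op = ready.front();
    ready.pop_front();
    std::cout << "apply grad of " << op->name << "\n";  // ApplyGrad() stand-in
    for (Node* pre : op->pre_ops) {
      if (--deps[pre] == 0) ready.push_back(pre);  // all consumers finished
    }
  }
}

int main() {
  // x -> a -> c -> loss, and a -> b -> c (a feeds two consumers).
  Node x{"x", {}}, a{"a", {&x}}, b{"b", {&a}}, c{"c", {&a, &b}}, loss{"loss", {&c}};
  RunBackward(&loss);  // prints: loss, c, b, a, x
}
```

The sketch only covers scheduling order; how gradients from several consumers of the same variable are combined is outside its scope.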
@@ -184,27 +163,22 @@ void VarBase::ApplyGrad(framework::Scope* scope, Variable* grad) {
 std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
   VLOG(3) << "op grad " << grad_op_desc_->Type();
 
-  for (const std::string& invar : grad_op_desc_->InputArgumentNames()) {
-    block_->FindRecursiveOrCreateVar(invar);
-    framework::Variable* var = scope->Var(invar);
-    LOG(ERROR) << "op grad in var " << invar;
-    if (!var->IsInitialized()) {
-      framework::VarDesc* var_desc = block_->FindVar(invar);
-      if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
-        LOG(ERROR) << "grad op invar init " << invar;
-        var->GetMutable<framework::LoDTensor>();
-      } else {
-        LOG(ERROR) << "tracer doesn't support yet";
-      }
-    } else {
-      var->GetMutable<framework::LoDTensor>()->type();
-    }
-  }
-
-  std::vector<Variable*> ret;
-  for (size_t i = 0; i < input_vars_->size(); ++i) {
-    ret.push_back(nullptr);
+  for (const std::string& grad_invar : grad_op_desc_->InputArgumentNames()) {
+    if (grad_to_var_->find(grad_invar) == grad_to_var_->end()) {
+      continue;
+    }
+    LOG(ERROR) << "op grad in var " << grad_invar;
+    block_->FindRecursiveOrCreateVar(grad_invar);
+    framework::Variable* var = scope->Var(grad_invar);
+    const std::string& invar = grad_to_var_->at(grad_invar);
+    for (VarBase* varbase : *output_vars_) {
+      if (varbase->var_desc_->Name() == invar) {
+        var->GetMutable<framework::LoDTensor>()->ShareDataWith(
+            varbase->grads_->Get<framework::LoDTensor>());
+      }
+    }
   }
 
   for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
     LOG(ERROR) << "grad outvar " << outvar;
     block_->FindRecursiveOrCreateVar(outvar);
@@ -225,23 +199,25 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
   opbase->Run(*scope, platform::CPUPlace());
 
-  for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
-    if (grad_to_var_->find(outvar) != grad_to_var_->end()) {
-      std::string origin_var = (*grad_to_var_)[outvar];
-      for (size_t i = 0; i < input_vars_->size(); ++i) {
-        VarBase* origin_in_var = (*input_vars_)[i];
-        if (origin_in_var->var_desc_->Name() == origin_var) {
-          framework::Variable* var = scope->FindVar(outvar);
-          LOG(ERROR) << "apply grad " << outvar << " with origin "
-                     << origin_var;
-          origin_in_var->ApplyGrad(scope, var);
-          ret[i] = var;
-          // TODO(panyx0718): There might be 2 var with the same name. We
-          // currently assume the are the same Variable*. So it doesn't matter
-          // which one is used.
-          break;
-        }
-      }
-    }
-  }
+  std::vector<Variable*> ret;
+  for (size_t i = 0; i < input_vars_->size(); ++i) {
+    bool found = false;
+    for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
+      Variable* var = scope->FindVar(outvar);
+      VarBase* origin_var = (*input_vars_)[i];
+      std::string orig_var = grad_to_var_->at(outvar);
+      PADDLE_ENFORCE(origin_var->var_desc_->Name() == orig_var);
+      LOG(ERROR) << "apply grad " << outvar << " with origin " << orig_var;
+      origin_var->ApplyGrad(scope, var);
+      found = true;
+      ret.push_back(var);
+      // TODO(panyx0718): There might be another outvar with the same name.
+      // In that case, it doesn't matter the first one or the second one is
+      // used.
+      break;
+    }
+    if (!found) {
+      ret.push_back(nullptr);
+    }
+  }
   return ret;
......
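Reviewer note: the other half of the change is in OpBase::ApplyGrad. Gradient inputs of the grad op are now filled by sharing the grads_ tensor of the matching forward output (looked up through grad_to_var_), and the return value is rebuilt so that slot i holds the gradient Variable of forward input i, or nullptr when the grad op produced no gradient for it. The following sketch shows only that name alignment, using plain strings and maps as hypothetical stand-ins for Paddle's Scope/Variable machinery.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Hypothetical forward op: Y = f(X0, X1). Its grad op emits Xi@GRAD outputs,
  // and grad_to_var maps each gradient name back to its forward variable.
  std::vector<std::string> input_vars = {"X0", "X1"};  // forward inputs
  std::vector<std::string> grad_op_outputs = {"X0@GRAD", "X1@GRAD"};
  std::map<std::string, std::string> grad_to_var = {
      {"Y@GRAD", "Y"}, {"X0@GRAD", "X0"}, {"X1@GRAD", "X1"}};

  // Align the grad op's outputs with forward-input positions, mirroring the
  // found/nullptr bookkeeping used to build ret in the rewritten ApplyGrad.
  std::vector<const std::string*> ret(input_vars.size(), nullptr);
  for (size_t i = 0; i < input_vars.size(); ++i) {
    for (const std::string& outvar : grad_op_outputs) {
      if (grad_to_var.at(outvar) == input_vars[i]) {
        ret[i] = &outvar;  // this grad output feeds forward input i
        break;
      }
    }
  }

  for (size_t i = 0; i < ret.size(); ++i) {
    std::cout << input_vars[i] << " <- "
              << (ret[i] ? *ret[i] : std::string("(no grad)")) << "\n";
  }
  return 0;
}
```

In the real code the inner loop also fetches the scope Variable and calls VarBase::ApplyGrad on it; the sketch keeps only the index alignment between grad outputs and forward inputs.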