提交 c3236f82 编写于 作者: X Xin Pan

polish

上级 e5d64fd4
...@@ -44,16 +44,12 @@ class Autograd { ...@@ -44,16 +44,12 @@ class Autograd {
public: public:
explicit Autograd(framework::Scope* scope) : scope_(scope) {} explicit Autograd(framework::Scope* scope) : scope_(scope) {}
void RunBackward(VarBase* var, framework::Variable* grad) { void RunBackward(VarBase* var) {
if (!var->pre_op_) {
var->ApplyGrad(scope_, grad);
return;
}
PADDLE_ENFORCE(var->pre_op_->op_desc_); PADDLE_ENFORCE(var->pre_op_->op_desc_);
// TODO(panyx0718): Only create vars that "require_grad" // TODO(panyx0718): Only create vars that "require_grad"
std::vector<Variable*> op_grads = std::vector<Variable*> op_grads =
CreateOpGrads(var->pre_op_->output_vars_->size()); CreateOpGrads(var->pre_op_->output_vars_->size());
op_grads[var->pre_op_out_idx_] = grad; op_grads[var->pre_op_out_idx_] = var->grads_;
std::deque<std::pair<OpBase*, std::vector<Variable*>>> ready; std::deque<std::pair<OpBase*, std::vector<Variable*>>> ready;
ready.push_back(std::make_pair(var->pre_op_, op_grads)); ready.push_back(std::make_pair(var->pre_op_, op_grads));
...@@ -238,8 +234,6 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) { ...@@ -238,8 +234,6 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
framework::Variable* var = scope->FindVar(outvar); framework::Variable* var = scope->FindVar(outvar);
LOG(ERROR) << "apply grad " << outvar << " with origin " LOG(ERROR) << "apply grad " << outvar << " with origin "
<< origin_var; << origin_var;
// TODO(panyx0718): Accumulate.
// origin_in_var->grads_ = var;
origin_in_var->ApplyGrad(scope, var); origin_in_var->ApplyGrad(scope, var);
ret[i] = var; ret[i] = var;
// TODO(panyx0718): There might be 2 var with the same name. We // TODO(panyx0718): There might be 2 var with the same name. We
...@@ -254,15 +248,11 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) { ...@@ -254,15 +248,11 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
} }
// Kick off backpropagation starting from this variable.
//
// Creates this VarBase's own gradient variable — named with the framework's
// standard grad-var naming (GradVarName), shaped like the forward value, and
// filled with 1.0 as the seed gradient — then walks the recorded op graph
// backwards via Autograd.
//
// @param scope  Scope in which gradient variables are created/looked up.
void VarBase::RunBackward(framework::Scope* scope) {
  // Seed gradient: same dims as the forward tensor, initialized to 1.0.
  // (The `false` flag is passed through to CreateVariable — presumably
  // "don't reuse an existing var"; confirm against its declaration.)
  grads_ = CreateVariable(framework::GradVarName(var_desc_->Name()),
                          var_->Get<framework::LoDTensor>().dims(), 1.0, scope,
                          false);
  // Leaf variable: no producing op recorded, so there is nothing to
  // propagate through — the seed gradient above is the final gradient.
  if (!pre_op_) return;
  Autograd(scope).RunBackward(this);
}
} // namespace imperative } // namespace imperative
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册