Commit c3236f82 authored by Xin Pan

polish

Parent e5d64fd4
@@ -44,16 +44,12 @@ class Autograd {
  public:
   explicit Autograd(framework::Scope* scope) : scope_(scope) {}
 
-  void RunBackward(VarBase* var, framework::Variable* grad) {
-    if (!var->pre_op_) {
-      var->ApplyGrad(scope_, grad);
-      return;
-    }
+  void RunBackward(VarBase* var) {
     PADDLE_ENFORCE(var->pre_op_->op_desc_);
     // TODO(panyx0718): Only create vars that "require_grad"
     std::vector<Variable*> op_grads =
         CreateOpGrads(var->pre_op_->output_vars_->size());
-    op_grads[var->pre_op_out_idx_] = grad;
+    op_grads[var->pre_op_out_idx_] = var->grads_;
 
     std::deque<std::pair<OpBase*, std::vector<Variable*>>> ready;
     ready.push_back(std::make_pair(var->pre_op_, op_grads));
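Editor's note: the hunk above drops the explicit seed-gradient parameter, so Autograd::RunBackward now reads the seed from the VarBase's own grads_ member instead of taking a caller-allocated Variable. A minimal sketch of how a call site changes under the new signature (the real caller is VarBase::RunBackward in the last hunk below):

    // before this commit: the caller allocates and passes the seed gradient
    Autograd(scope).RunBackward(var, grad);
    // after: the seed lives on the variable itself
    Autograd(scope).RunBackward(var);  // reads var->grads_ internally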
@@ -238,8 +234,6 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
       framework::Variable* var = scope->FindVar(outvar);
       LOG(ERROR) << "apply grad " << outvar << " with origin "
                  << origin_var;
-      // TODO(panyx0718): Accumulate.
-      // origin_in_var->grads_ = var;
       origin_in_var->ApplyGrad(scope, var);
       ret[i] = var;
       // TODO(panyx0718): There might be 2 var with the same name. We
@@ -254,15 +248,11 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
 }
 
 void VarBase::RunBackward(framework::Scope* scope) {
-  // TODO(panyx0718): Might not be 0th, need to detect.
-  grads_ = CreateVariable(pre_op_->grad_op_desc_->InputArgumentNames()[0],
+  grads_ = CreateVariable(framework::GradVarName(var_desc_->Name()),
                           var_->Get<framework::LoDTensor>().dims(), 1.0, scope,
                           false);
-
-  framework::Variable* grad =
-      CreateVariable("init@imperative_grad",
-                     var_->Get<framework::LoDTensor>().dims(), 1.0, scope);
-  Autograd(scope).RunBackward(this, grad);
+  if (!pre_op_) return;
+  Autograd(scope).RunBackward(this);
 }
 
 }  // namespace imperative
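Editor's note: taken together, the three hunks make VarBase::RunBackward the single backward entry point: it creates the variable's own gradient (named via framework::GradVarName rather than by peeking at the grad op's 0th input), seeds it with 1.0, returns early for leaf variables, and only then lets Autograd walk the producing ops. Below is a self-contained toy sketch of that ownership pattern; all types here are invented for illustration and are not Paddle's actual classes:

    #include <cstdio>
    #include <deque>
    #include <vector>

    struct Op;  // forward declaration

    // Toy counterpart of VarBase: the variable owns its gradient (grads_),
    // so backward needs no caller-supplied seed.
    struct Var {
      float value = 0.f;
      float grad = 0.f;
      Op* pre_op = nullptr;  // op that produced this var, if any
    };

    // Toy counterpart of OpBase with a trivial grad rule: d(out)/d(in) = 1
    // for every input (i.e. the op is a sum).
    struct Op {
      std::vector<Var*> inputs;
      Var* output = nullptr;
      void ApplyGrad() {
        for (Var* in : inputs) in->grad += output->grad;
      }
    };

    // Counterpart of VarBase::RunBackward after this commit: seed the
    // variable's own grad with 1.0, return early for leaves, then walk
    // producing ops breadth-first (the deque mirrors Autograd's `ready`).
    void RunBackward(Var* var) {
      var->grad = 1.0f;
      if (!var->pre_op) return;
      std::deque<Op*> ready = {var->pre_op};
      while (!ready.empty()) {
        Op* op = ready.front();
        ready.pop_front();
        op->ApplyGrad();
        for (Var* in : op->inputs) {
          if (in->pre_op) ready.push_back(in->pre_op);
        }
      }
    }

    int main() {
      Var a, b, c;
      Op sum;
      sum.inputs = {&a, &b};
      sum.output = &c;
      c.pre_op = &sum;  // c = a + b
      RunBackward(&c);
      std::printf("dc/da = %g, dc/db = %g\n", a.grad, b.grad);  // both 1
      return 0;
    }

The += in the toy ApplyGrad is where gradients accumulate when a variable fans out to several consumers, which is presumably the concern behind the "Accumulate" TODO removed in the second hunk.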