From e9fdf9090d9c6c4f5453c671db6951076d7b3ad0 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 25 Feb 2019 11:44:49 +0800 Subject: [PATCH] Polish code test=develop --- paddle/fluid/imperative/layer.cc | 16 ++++++++++++++++ paddle/fluid/imperative/layer.h | 18 ++---------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 9e627f594dc..8f20f0c06e0 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -271,6 +271,22 @@ std::map> OpBase::ApplyGrad() { return input_vars_; } +void VarBase::RunBackward() { + if (!pre_op_) return; + + VLOG(3) << "start backward"; + auto grads_t = grads_->var_->GetMutable(); + operators::math::set_constant( + *(platform::DeviceContextPool::Instance().Get( + var_->GetMutable()->place())), + grads_t, 1.0); + + PADDLE_ENFORCE( + grads_ == + pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_); + Autograd().RunBackward(this); +} + void PyLayer::RegisterFunc(int func_id, const py::object& py_func) { py_funcs_[func_id] = py_func; } diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 10e2bb40826..9adc81f04dd 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -140,6 +140,8 @@ class VarBase { } inline bool IsStopGradient() const { return stop_gradient_; } + void RunBackward(); + void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name, int pre_op_out_idx, bool pre_op_stop_gradient) { pre_op_ = pre_op; @@ -150,22 +152,6 @@ class VarBase { } } - void RunBackward() { - if (!pre_op_) return; - - VLOG(3) << "start backward"; - auto grads_t = grads_->var_->GetMutable(); - operators::math::set_constant( - *(platform::DeviceContextPool::Instance().Get( - var_->GetMutable()->place())), - grads_t, 1.0); - - PADDLE_ENFORCE( - grads_ == - pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_); - Autograd().RunBackward(this); - } - void ClearGradient() { VLOG(1) << "clear gradient of " << var_desc_->Name(); if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) { -- GitLab