From a15a3fc314c9b683dcc346ffd5343f3e6c7ff1ce Mon Sep 17 00:00:00 2001 From: minqiyang Date: Sat, 23 Feb 2019 23:51:34 +0800 Subject: [PATCH] Polish code test=develop --- paddle/fluid/framework/block_desc.cc | 2 +- paddle/fluid/framework/block_desc.h | 2 +- paddle/fluid/imperative/layer.cc | 27 --------------------------- paddle/fluid/imperative/layer.h | 27 +++++++++++++++++++++++++-- paddle/fluid/imperative/tracer.cc | 6 +++--- paddle/fluid/pybind/protobuf.cc | 3 +-- 6 files changed, 31 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index 174c77a69b9..f4bb2f3e2fc 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -163,7 +163,7 @@ std::vector BlockDesc::AllOps() const { return res; } -void BlockDesc::ClearBlock() { +void BlockDesc::Clear() { // clear all ops ops_.clear(); diff --git a/paddle/fluid/framework/block_desc.h b/paddle/fluid/framework/block_desc.h index 651841daea4..e192624a261 100644 --- a/paddle/fluid/framework/block_desc.h +++ b/paddle/fluid/framework/block_desc.h @@ -97,7 +97,7 @@ class BlockDesc { std::vector AllOps() const; - void ClearBlock(); + void Clear(); size_t OpSize() const { return ops_.size(); } diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index fd1b64ee8be..9e627f594dc 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -205,33 +205,6 @@ framework::LoDTensor& VarBase::GradValue() { return *(grads_->var_->GetMutable()); } -void VarBase::ClearGradient() { - VLOG(1) << "clear gradient of " << var_desc_->Name(); - if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) { - auto grads_t = grads_->var_->GetMutable(); - operators::math::set_constant( - *(platform::DeviceContextPool::Instance().Get( - grads_->var_->Get().place())), - grads_t, 0.0); - } -} - -void VarBase::RunBackward() { - if (!pre_op_) return; - - VLOG(3) << "start backward"; - auto grads_t = grads_->var_->GetMutable(); - operators::math::set_constant( - *(platform::DeviceContextPool::Instance().Get( - var_->GetMutable()->place())), - grads_t, 1.0); - - PADDLE_ENFORCE( - grads_ == - pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_); - Autograd().RunBackward(this); -} - std::map> OpBase::ApplyGrad() { if (grad_op_descs_.empty() && backward_id_ <= 0) { VLOG(3) << "op with no grad: " << op_desc_->Type(); diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 0ebc3c9a7d2..10e2bb40826 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -150,9 +150,32 @@ class VarBase { } } - void RunBackward(); + void RunBackward() { + if (!pre_op_) return; - void ClearGradient(); + VLOG(3) << "start backward"; + auto grads_t = grads_->var_->GetMutable(); + operators::math::set_constant( + *(platform::DeviceContextPool::Instance().Get( + var_->GetMutable()->place())), + grads_t, 1.0); + + PADDLE_ENFORCE( + grads_ == + pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_); + Autograd().RunBackward(this); + } + + void ClearGradient() { + VLOG(1) << "clear gradient of " << var_desc_->Name(); + if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) { + auto grads_t = grads_->var_->GetMutable(); + operators::math::set_constant( + *(platform::DeviceContextPool::Instance().Get( + grads_->var_->Get().place())), + grads_t, 0.0); + } + } framework::LoDTensor& GradValue(); diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index f9f8d04db21..fd9e61d7c25 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -145,7 +145,7 @@ std::set Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, prepared_op.func(framework::ExecutionContext( prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx)); - std::set grad_deps_var; + std::set vars_saved_for_backward; if (!stop_gradient) { std::unique_ptr> grad_to_var( @@ -166,7 +166,7 @@ std::set Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, PADDLE_ENFORCE(fwd_var_it != vars.end()); // Forward inputs or outputs. grad_in_vars.push_back(fwd_var_it->second->var_); - grad_deps_var.insert(it.first); + vars_saved_for_backward.insert(it.first); } else { VarBase* var = vars[var_it->second]; if (!var->grads_->var_->IsInitialized()) { @@ -200,7 +200,7 @@ std::set Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, } op->block_ = block; - return grad_deps_var; + return vars_saved_for_backward; } std::vector Tracer::PyTrace(OpBase* op, diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 6bfee48af83..48fe445b7d0 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -189,8 +189,7 @@ void BindBlockDesc(pybind11::module *m) { return self.HasVar(name); }, pybind11::return_value_policy::reference) - .def("_clear_block", - [](pd::BlockDesc &self) { return self.ClearBlock(); }, + .def("_clear_block", [](pd::BlockDesc &self) { return self.Clear(); }, pybind11::return_value_policy::reference) .def("_rename_var", [](pd::BlockDesc &self, const pybind11::bytes &byte_name, -- GitLab