From 3b70f870e27e5d14c1b1fd6a4bf8bdc2d4660060 Mon Sep 17 00:00:00 2001 From: Jiabin Yang Date: Tue, 4 Jun 2019 11:46:05 +0800 Subject: [PATCH] Using Smart pointer to optimizer memory usage of dyGraph (#17768) * for debug * test=develop, memory optimize for dygraph using shared_ptr * test=develop, fix travis ci showed error * test=develop, fix bug for recurrent usage of varbase * test=develop, init varbase when it need to be Add --- paddle/fluid/imperative/layer.cc | 182 +++++++++--------- paddle/fluid/imperative/layer.h | 71 ++++--- paddle/fluid/imperative/tracer.cc | 63 +++--- paddle/fluid/imperative/tracer.h | 13 +- paddle/fluid/imperative/type_defs.h | 12 +- paddle/fluid/pybind/imperative.cc | 28 +-- python/paddle/fluid/dygraph/tracer.py | 25 +-- .../test_imperative_recurrent_usage.py | 87 +++++++++ 8 files changed, 277 insertions(+), 204 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_imperative_recurrent_usage.py diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index e48945b8e2..4bced3a0e8 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/imperative/layer.h" +#include #include #include #include @@ -77,52 +78,63 @@ class TensorAddToFunctor : public boost::static_visitor<> { } // namespace detail -void AddTo(Variable* src, Variable* dst, platform::Place place) { - framework::Tensor* dst_tensor = dst->GetMutable(); - framework::Tensor* src_tensor = src->GetMutable(); - - // FIXME(minqiyang): loss_grad op will pass a zero grad of label - // ugly fix for it - if (src_tensor->numel() == 0) { +void AddTo(std::shared_ptr src, std::shared_ptr dst, + platform::Place place) { + if (!dst->IsInitialize()) { + VLOG(2) << "im here1"; + PADDLE_ENFORCE(src->IsInitialize(), "Using uninitialized VarBase"); + dst->var_ = std::move(src->var_); + dst->SetInitialize(true); return; - } + } else { + framework::Tensor* dst_tensor = + dst->var_->GetMutable(); + framework::Tensor* src_tensor = + src->var_->GetMutable(); + + // FIXME(minqiyang): loss_grad op will pass a zero grad of label + // ugly fix for it + if (src_tensor->numel() == 0) { + return; + } - PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(), - "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(), - src_tensor->numel()); + PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(), + "dst_numel %lld vs. 
src_numel %lld", dst_tensor->numel(), + src_tensor->numel()); - detail::TensorAddToFunctor func( - src_tensor->numel(), src_tensor->data(), - dst_tensor->mutable_data(place)); - boost::apply_visitor(func, place); + detail::TensorAddToFunctor func( + src_tensor->numel(), src_tensor->data(), + dst_tensor->mutable_data(place)); + boost::apply_visitor(func, place); + } } -void ZeroGrads(VarBase* vb, const platform::Place& place) { +void ZeroGrads(const std::shared_ptr vb, + const platform::Place& place) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); auto grad_t = vb->var_->GetMutable(); operators::math::set_constant(*dev_ctx, grad_t, 0.0); } -void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) { - PADDLE_ENFORCE(bck_map->find(target) != bck_map->end(), +void AddGradBySort(BackwardSumMap* bck_map, + std::shared_ptr target) { + PADDLE_ENFORCE(bck_map->find(target.get()) != bck_map->end(), "Can't find %s in backward grad map", target->Name()); - std::pair>>& current = - bck_map->at(target); - std::sort( - current.second.begin(), current.second.end(), - [](const std::pair& a, const std::pair& b) { - return a.first > b.first; - }); + std::pair>>>& + current = bck_map->at(target.get()); + std::sort(current.second.begin(), current.second.end(), + [](const std::pair>& a, + const std::pair>& b) { + return a.first > b.first; + }); for (auto& var_pair : current.second) { - Variable* origin_grad = target->var_.get(); - Variable* grad_to_add = var_pair.second->var_.get(); VLOG(10) << "add origin_grad: " << target->Name(); VLOG(10) << "added grad: " << var_pair.second->Name() << " trace id is: " << var_pair.first; - AddTo(grad_to_add, origin_grad, current.first); - delete var_pair.second; - var_pair.second = nullptr; + AddTo(var_pair.second, target, current.first); + var_pair.second.reset(); } } @@ -146,24 +158,22 @@ class Autograd { while (!ready.empty()) { OpBase* ready_op = ready.front(); ready.pop_front(); - std::map> input_grads = + std::vector grads_outputs = ready_op->ApplyGrad(&bck_map, &grad_ref, bck_stratedy); - for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) { - const std::vector& ingrads = it->second; - for (size_t i = 0; i < ingrads.size(); ++i) { - if (!ingrads[i]) continue; - auto p = ready_op->input_vars_[it->first][i]; - - if (p->IsStopGradient()) continue; - OpBase* pre_op = ready_op->pre_ops_[it->first][i]; - if (!pre_op) continue; - - dep_counts[pre_op] -= 1; - PADDLE_ENFORCE(dep_counts[pre_op] >= 0); - bool pre_op_ready = dep_counts[pre_op] == 0; - if (pre_op_ready) { - ready.push_back(pre_op); + for (const auto& map : grads_outputs) { + for (auto it = map.rbegin(); it != map.rend(); ++it) { + const std::vector>& grad_outs = it->second; + for (size_t i = 0; i < grad_outs.size(); ++i) { + if (!grad_outs[i] || grad_outs[i]->IsStopGradient()) continue; + OpBase* pre_op = grad_outs[i]->PreOp(); + if (!pre_op) continue; + dep_counts[pre_op] -= 1; + PADDLE_ENFORCE(dep_counts[pre_op] >= 0); + bool pre_op_ready = dep_counts[pre_op] == 0; + if (pre_op_ready) { + ready.push_back(pre_op); + } } } } @@ -194,7 +204,7 @@ class Autograd { for (const auto& map : candidate->grad_output_vars_) { for (const auto& it : map) { for (const auto& vb : it.second) { - ++(*grad_ref)[vb]; + ++(*grad_ref)[vb.get()]; } } } @@ -202,7 +212,7 @@ class Autograd { for (auto it : candidate->pre_ops_) { for (OpBase* pre_op : it.second) { if (!pre_op) continue; - VLOG(9) << "op dep " << candidate->Type() << " trace id " + VLOG(2) 
<< "op dep " << candidate->Type() << " trace id " << candidate->trace_id_ << " <---- " << it.first << " <---- " << pre_op->Type() << " trace id " << pre_op->trace_id_; if (visited.find(pre_op) == visited.end()) { @@ -254,7 +264,7 @@ framework::LoDTensor& VarBase::GradValue() { return *(grads_->var_->GetMutable()); } -std::map> OpBase::ApplyGrad( +std::vector OpBase::ApplyGrad( BackwardSumMap* bck_map, GradientRef* grad_ref, const detail::BackwardStrategy& bck_stratedy) { PADDLE_ENFORCE(!grad_op_descs_.empty(), "%s has no backward implementation", @@ -274,17 +284,14 @@ std::map> OpBase::ApplyGrad( for (const auto& it : grad_output_variable_map) { auto& outputs = tmp_grad_outputs[k][it.first]; outputs.reserve(it.second.size()); - for (VarBase* origin_grad_var_base : it.second) { - if (!origin_grad_var_base->IsInitialize()) { - origin_grad_var_base->InitBuffer(); - ZeroGrads(origin_grad_var_base, place_); - } + for (const std::shared_ptr& origin_grad_var_base : + it.second) { // Allocate a new variable - VarBase* tmp_grad_var_base = new VarBase( + std::shared_ptr tmp_grad_var_base(new VarBase( string::Sprintf("%s@IGrad", origin_grad_var_base->Name()), origin_grad_var_base->DataType(), origin_grad_var_base->Dims(), - place_, true, false); - outputs.emplace_back(tmp_grad_var_base); + place_, true, false)); + outputs.emplace_back(std::move(tmp_grad_var_base)); } } @@ -298,7 +305,7 @@ std::map> OpBase::ApplyGrad( auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type()); if (info.infer_var_type_) { RuntimeInferVarTypeContext infer_var_type_ctx( - &grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_); + &grad_input_vars_[k], &tmp_grad_outputs[k], &(opbase->Attrs())); info.infer_var_type_(&infer_var_type_ctx); } @@ -313,14 +320,14 @@ std::map> OpBase::ApplyGrad( for (const auto& it : grad_input_vars_[k]) { auto& grad_invars = grad_invars_map[it.first]; grad_invars.reserve(it.second.size()); - for (VarBase* grad_inp : it.second) { + for (const std::shared_ptr& grad_inp : it.second) { PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr", grad_op_desc->Type(), grad_inp->Name()); if (!grad_inp->IsInitialize()) { grad_inp->InitBuffer(); ZeroGrads(grad_inp, place_); } - const VarBase* const_grad_inp = grad_inp; + const std::shared_ptr& const_grad_inp = grad_inp; grad_invars.emplace_back(const_grad_inp->var_.get()); } } @@ -328,7 +335,7 @@ std::map> OpBase::ApplyGrad( for (const auto& it : tmp_grad_outputs[k]) { auto& grad_outvars = grad_outvars_map[it.first]; grad_outvars.reserve(it.second.size()); - for (VarBase* grad_out : it.second) { + for (const std::shared_ptr& grad_out : it.second) { PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr", grad_op_desc->Type(), grad_out->Name()); @@ -355,56 +362,48 @@ std::map> OpBase::ApplyGrad( for (size_t i = 0; i < outputs.size(); ++i) { // track outputs used by sum if (bck_stratedy.sorted_sum_gradient_) { -#ifndef PADDLE_WITH_CUDA - VLOG(10) << "origin_outputs is : " << origin_outputs[i]->Name() - << " "; - VLOG(10) << origin_outputs[i] - ->var_->GetMutable() - ->data()[0]; - VLOG(10) << "outputs is : " << outputs[i]->Name() << " "; - VLOG(10) << outputs[i] - ->var_->GetMutable() - ->data()[0]; -#endif - if (bck_map->find(origin_outputs[i]) != bck_map->end()) { + if (bck_map->find(origin_outputs[i].get()) != bck_map->end()) { VLOG(10) << "add sub grad to " << origin_outputs[i]->Name(); - bck_map->at(origin_outputs[i]) + bck_map->at(origin_outputs[i].get()) .second.emplace_back( - std::pair(this->trace_id_, outputs[i])); + 
std::pair>( + this->trace_id_, std::move(outputs[i]))); } else { VLOG(10) << "insert new map for " << origin_outputs[i]->Name(); - std::pair>> - tmp(place_, {std::make_pair(this->trace_id_, outputs[i])}); - bck_map->insert(std::make_pair(origin_outputs[i], tmp)); + std::pair>>> + tmp(place_, + {std::make_pair(this->trace_id_, std::move(outputs[i]))}); + bck_map->insert(std::make_pair(origin_outputs[i].get(), tmp)); } - PADDLE_ENFORCE(grad_ref->find(origin_outputs[i]) != grad_ref->end(), - "Can't find %s in grad_reference count map", - origin_outputs[i]->Name()); - PADDLE_ENFORCE(grad_ref->at(origin_outputs[i]) >= 1, + PADDLE_ENFORCE( + grad_ref->find(origin_outputs[i].get()) != grad_ref->end(), + "Can't find %s in grad_reference count map", + origin_outputs[i]->Name()); + PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()) >= 1, "Backward error when calculate grad reference"); - if (grad_ref->at(origin_outputs[i]) > 1) { + if (grad_ref->at(origin_outputs[i].get()) > 1) { VLOG(10) << "remove ref for " << origin_outputs[i]->Name(); - grad_ref->at(origin_outputs[i])--; + grad_ref->at(origin_outputs[i].get())--; } else { VLOG(10) << "Add grad for: " << origin_outputs[i]->Name(); AddGradBySort(bck_map, origin_outputs[i]); - grad_ref->at(origin_outputs[i])--; + grad_ref->at(origin_outputs[i].get())--; } } else { - framework::Variable* grad = outputs[i]->var_.get(); - framework::Variable* orig_grad = origin_outputs[i]->var_.get(); VLOG(10) << "AddTo Called with orig_grad is: " << origin_outputs[i]->name_ << " Grad to be added is " << outputs[i]->name_; - AddTo(grad, orig_grad, place_); - delete outputs[i]; + AddTo(outputs[i], origin_outputs[i], place_); + outputs[i].reset(); } } } } - return input_vars_; + return grad_output_vars_; } void OpBase::InvokeBackwardHooks() { @@ -434,9 +433,6 @@ void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) { var_->GetMutable()->place())), grads_t, 1.0); - PADDLE_ENFORCE( - grads_ == - pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_); Autograd().RunBackward(this, bck_stratedy); } diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 5c7cc20433..3d31001df9 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -171,32 +171,27 @@ class VarBase { if (need_initialize) { tensor->mutable_data(place, dtype); is_initialized_ = true; - VLOG(2) << "initialized varbase: " << name_ << " type: " << dtype + VLOG(8) << "initialized varbase: " << name_ << " type: " << dtype << " place: " << place; } else { is_initialized_ = false; - VLOG(2) << "not initialized varbase: " << name_; + VLOG(8) << "not initialized varbase: " << name_; } - VLOG(2) << "create varbase: " << name_ << " type: " << dtype - << " place: " << place; + VLOG(8) << "create varbase: " << name_ << " type: " << dtype + << " place: " << place << "Stop gradient: " << stop_gradient_; } public: virtual ~VarBase() { - if (grads_) { - delete grads_; - grads_ = nullptr; - } - pre_op_ = nullptr; pre_op_out_idx_ = -1; - VLOG(2) << "destruct varbase: " << name_; + VLOG(8) << "destruct varbase: " << name_; } inline void SetName(const std::string& name) { name_ = name; } inline std::string Name() const { return name_; } inline bool IsInitialize() const { return is_initialized_; } - + inline void SetInitialize(bool inited) { is_initialized_ = inited; } inline std::vector Shape() const { if (var_->IsInitialized()) { return framework::vectorize(var_->Get().dims()); @@ -214,10 +209,7 @@ class VarBase { auto tensor = 
var_->GetMutable(); tensor->mutable_data(tensor->place(), type); } - inline framework::proto::VarType::Type DataType() const { - auto tensor = var_->Get(); - return tensor.type(); - } + inline framework::proto::VarType::Type DataType() const { return dtype_; } // tensor type. e.g.. LoDTensor inline void SetType(framework::proto::VarType::Type type) { type_ = type; } @@ -225,11 +217,15 @@ class VarBase { inline void SetStopGradient(bool stop_gradient) { stop_gradient_ = stop_gradient; + if (grads_) { + grads_->stop_gradient_ = stop_gradient; + } } inline bool IsStopGradient() const { return stop_gradient_; } inline void SetPersistable(bool persistable) { persistable_ = persistable; } inline bool IsPersistable() const { return persistable_; } + inline void SetPreOp(OpBase* op) { pre_op_ = op; } inline platform::Place GetPlace() { return place_; } inline OpBase* PreOp() const { return pre_op_; } inline int PreOpOutIdx() const { return pre_op_out_idx_; } @@ -248,10 +244,10 @@ class VarBase { if (!is_initialized_) { var_->GetMutable()->mutable_data(place_, dtype_); is_initialized_ = true; - VLOG(2) << "initialized varbase: " << name_ << " type: " << dtype_ + VLOG(8) << "initialized varbase: " << name_ << " type: " << dtype_ << " place: " << place_; } else { - VLOG(2) << "var: " << name_ << " has already been initialized "; + VLOG(8) << "var: " << name_ << " has already been initialized "; } } @@ -290,7 +286,7 @@ class VarBase { platform::Place place_; std::unique_ptr var_; - VarBase* grads_; + std::shared_ptr grads_; private: framework::proto::VarType::Type dtype_; @@ -314,22 +310,23 @@ class PYBIND11_HIDDEN OpBase { backward_hooks_() {} virtual ~OpBase() { - // TODO(minqiyang): remove op_desc from block_desc in tracer - // - // reset all output vars' pre op - for (auto iter : output_vars_) { - for (VarBase* var : iter.second) { - var->ResetPreOp(this); + for (const auto& iter : outputs_ref) { + for (const auto& var : iter.second) { + auto vb = var.lock(); + if (vb) { + VLOG(3) << "Op reset by" << vb->name_; + vb->ResetPreOp(this); + } } } - + // TODO(minqiyang): remove op_desc from block_desc in tracer // release resource for (framework::OpDesc* desc : grad_op_descs_) { delete desc; } } - std::map> ApplyGrad( + std::vector ApplyGrad( BackwardSumMap* bck_map, GradientRef* grad_ref, const detail::BackwardStrategy& bck_stratedy); @@ -343,12 +340,13 @@ class PYBIND11_HIDDEN OpBase { void InvokeBackwardHooks(); - void TrackPreOp(const std::string& inp_name, - const std::vector& inputs) { + void TrackPreOp( + const std::string& inp_name, + const std::vector>& inputs) { auto& pre_ops_list = pre_ops_[inp_name]; pre_ops_list.reserve(inputs.size()); auto& pre_ops_out_idx_list = pre_ops_out_idx_[inp_name]; - for (VarBase* inp_var : inputs) { + for (std::shared_ptr inp_var : inputs) { if (inp_var->PreOp() && !inp_var->IsStopGradient()) { VLOG(3) << "add pre op " << inp_var->PreOp()->Type() << " in slot " << inp_name; @@ -371,11 +369,10 @@ class PYBIND11_HIDDEN OpBase { platform::Place place_; - VarBasePtrMap input_vars_; - VarBasePtrMap output_vars_; OpBasePtrMap pre_ops_; std::map> pre_ops_out_idx_; + VarBaseWeakPtrMap outputs_ref; // Inputs to a vector of bwd ops. std::vector grad_input_vars_; // Outputs to a vector of bwd ops. 
@@ -390,8 +387,9 @@ class Layer { public: virtual ~Layer() {} - virtual std::vector Forward(const std::vector& inputs) { - std::vector vars; + virtual std::vector> Forward( + const std::vector>& inputs) { + std::vector> vars; return vars; } }; @@ -412,7 +410,7 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext var_set_() { input_names_.reserve(inputs_->size()); for (auto& it : *inputs_) { - for (imperative::VarBase* var : it.second) { + for (std::shared_ptr var : it.second) { input_names_[it.first].emplace_back(var->Name()); var_set_[var->Name()] = var; } @@ -420,7 +418,7 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext output_names_.reserve(outputs_->size()); for (auto& it : *outputs_) { - for (imperative::VarBase* var : it.second) { + for (std::shared_ptr var : it.second) { output_names_[it.first].emplace_back(var->Name()); var_set_[var->Name()] = var; } @@ -516,7 +514,8 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext const framework::AttributeMap* attrs_; std::unordered_map> input_names_; std::unordered_map> output_names_; - std::unordered_map var_set_; + std::unordered_map> + var_set_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index b08c929aaf..bde5c6d400 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -46,23 +46,25 @@ void CreateGradOp(const framework::OpDesc& op_desc, } } -void CreateNoBuffuerGrad(VarBase* var, platform::DeviceContext* dev_ctx) { +void CreateNoBuffuerGrad(std::shared_ptr var, + platform::DeviceContext* dev_ctx) { PADDLE_ENFORCE_NOT_NULL(var, "Could not get valid var base"); PADDLE_ENFORCE_NOT_NULL(dev_ctx, "Could not get valid device from forward op"); if (var->grads_ == nullptr) { auto& var_t = var->var_->Get(); - var->grads_ = new VarBase(var->GradName(), framework::proto::VarType::FP32, - framework::vectorize(var_t.dims()), - dev_ctx->GetPlace(), true, false, false); + var->grads_ = std::shared_ptr( + new VarBase(var->GradName(), framework::proto::VarType::FP32, + framework::vectorize(var_t.dims()), dev_ctx->GetPlace(), + var->IsStopGradient(), false, false)); } } platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) { platform::Place result = place; - for (auto it : inputs) { - for (VarBase* var : it.second) { + for (const auto& it : inputs) { + for (const std::shared_ptr& var : it.second) { platform::Place tmp_place = var->var_->Get().place(); if (!platform::is_same_place(tmp_place, result)) { @@ -96,7 +98,7 @@ framework::VariableNameMap CreateInputVarNameMap( auto var_vector = it->second; std::vector args; args.reserve(var_vector.size()); - for (VarBase* var_base : var_vector) { + for (std::shared_ptr var_base : var_vector) { args.emplace_back(var_base->Name()); } result[in.name()] = args; @@ -124,7 +126,7 @@ framework::VariableNameMap CreateOutputVarNameMap( auto var_vector = it->second; std::vector args; args.reserve(var_vector.size()); - for (VarBase* var_base : var_vector) { + for (const std::shared_ptr& var_base : var_vector) { args.emplace_back(var_base->Name()); } result[out.name()] = args; @@ -135,22 +137,20 @@ framework::VariableNameMap CreateOutputVarNameMap( Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {} -std::set Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, - VarBasePtrMap* outputs, - framework::AttributeMap attrs_map, - const platform::Place expected_place, - const bool stop_gradient) { +void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, + VarBasePtrMap* 
outputs, framework::AttributeMap attrs_map, + const platform::Place expected_place, + const bool stop_gradient) { platform::RecordEvent record_event(op->type_); framework::VariableValueMap invars_map; framework::VariableValueMap outvars_map; // Construct input_vars_map and output_vars_map - std::map current_vars_map; - op->input_vars_ = inputs; - for (auto it : op->input_vars_) { + std::map> current_vars_map; + for (auto it : inputs) { auto& invars = invars_map[it.first]; invars.reserve(it.second.size()); - for (VarBase* inp : it.second) { + for (std::shared_ptr inp : it.second) { PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr", op->Type(), inp->Name()); @@ -165,13 +165,15 @@ std::set Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, op->TrackPreOp(it.first, it.second); } - op->output_vars_ = *outputs; - for (auto it : op->output_vars_) { + for (const auto& it : *outputs) { auto& outvars = outvars_map[it.first]; - const std::vector& outputs = it.second; - outvars.reserve(outputs.size()); - for (size_t i = 0U; i < outputs.size(); ++i) { - VarBase* out = outputs[i]; + const std::vector>& outputs_tmp = + it.second; + outvars.reserve(outputs_tmp.size()); + for (size_t i = 0U; i < outputs_tmp.size(); ++i) { + // Add weak_ptr to track outputs + op->outputs_ref[it.first].emplace_back(outputs_tmp[i]); + std::shared_ptr out = outputs_tmp[i]; outvars.emplace_back(out->var_.get()); out->TrackPreOp(op, it.first, i, stop_gradient); if (!stop_gradient) { @@ -223,8 +225,6 @@ std::set Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, framework::ExecutionContext(prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx, prepared_op.kernel_configs)); - // construct backward op - std::set vars_saved_for_backward; if (!stop_gradient) { VLOG(5) << "start construct backward op"; @@ -258,13 +258,13 @@ std::set Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, // Forward inputs or outputs. grad_in_vars.emplace_back(fwd_var_it->second); } else { - VarBase* var = current_vars_map[var_it->second]; + std::shared_ptr var = + current_vars_map[var_it->second]; CreateNoBuffuerGrad(var, prepared_op.GetDeviceContext()); // Douts. 
+ var->grads_->SetPreOp(var->PreOp()); grad_in_vars.emplace_back(var->grads_); } - - vars_saved_for_backward.insert(it.first); } } @@ -276,16 +276,17 @@ std::set Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, "Could not found the grad op output var, should this " "operator %s's stop gradient be True", op->Type()); - VarBase* var = current_vars_map[var_it->second]; + + std::shared_ptr var = + current_vars_map[var_it->second]; CreateNoBuffuerGrad(var, prepared_op.GetDeviceContext()); + var->grads_->SetPreOp(var->PreOp()); grad_out_vars.push_back(var->grads_); VLOG(3) << "grads output var name: " << var->name_; } } } } - - return vars_saved_for_backward; } } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index 9d95e3cd07..02d9022741 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -36,9 +36,6 @@ void CreateGradOp(const framework::OpDesc& op_desc, framework::OpDesc** grad_op_desc, std::unordered_map* grad_to_var); -void InitVar(const VarBase* var, framework::Variable* grad_var, - platform::DeviceContext* dev_ctx); - platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs); class Tracer { @@ -47,11 +44,11 @@ class Tracer { virtual ~Tracer() {} - std::set Trace(OpBase* op, const VarBasePtrMap& inputs, - VarBasePtrMap* outputs, // NOLINT - framework::AttributeMap attrs_map, - const platform::Place expected_place, - const bool stop_gradient = false); + void Trace(OpBase* op, const VarBasePtrMap& inputs, + VarBasePtrMap* outputs, // NOLINT + framework::AttributeMap attrs_map, + const platform::Place expected_place, + const bool stop_gradient = false); private: platform::Place GetPlace(const VarBasePtrMap& inputs); diff --git a/paddle/fluid/imperative/type_defs.h b/paddle/fluid/imperative/type_defs.h index 13d08cbb71..c22208a392 100644 --- a/paddle/fluid/imperative/type_defs.h +++ b/paddle/fluid/imperative/type_defs.h @@ -15,6 +15,7 @@ limitations under the License. 
*/ #pragma once #include +#include #include #include #include @@ -26,12 +27,17 @@ namespace imperative { class VarBase; class OpBase; -typedef std::map> VarBasePtrMap; -typedef std::map> ConstVarBasePtrMap; +typedef std::map>> + VarBasePtrMap; +typedef std::map>> + VarBaseWeakPtrMap; +typedef std::map>> + ConstVarBasePtrMap; typedef std::map> OpBasePtrMap; typedef std::unordered_map< const VarBase*, - std::pair>>> + std::pair>>>> BackwardSumMap; // var_grad -> {place, {id -> var_grad@rename}} typedef std::unordered_map GradientRef; diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 5110d5e40d..31156ab1c9 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -35,9 +35,11 @@ class Layer : public imperative::Layer { public: using imperative::Layer::Layer; // Inherit constructors - std::vector Forward( - const std::vector &inputs) override { - PYBIND11_OVERLOAD(std::vector, Layer, Forward, + std::vector> Forward( + const std::vector> &inputs) + override { + PYBIND11_OVERLOAD(std::vector>, Layer, + Forward, inputs); // NOLINT } }; @@ -72,7 +74,8 @@ void BindImperative(pybind11::module *m_ptr) { m.def("stop_imperative_gperf_profiler", []() { imperative::StopProfile(); }); - py::class_(m, "VarBase", R"DOC()DOC") + py::class_>( + m, "VarBase", R"DOC()DOC") .def( py::init, const paddle::platform::CPUPlace, @@ -136,10 +139,11 @@ void BindImperative(pybind11::module *m_ptr) { py::class_ layer(m, "Layer"); layer.def(py::init<>()) - .def("forward", [](imperative::Layer &self, - const std::vector &inputs) { - return self.Forward(inputs); - }); + .def("forward", + [](imperative::Layer &self, + const std::vector> &inputs) { + return self.Forward(inputs); + }); py::class_(*m, "Tracer", "") .def("__init__", @@ -154,8 +158,8 @@ void BindImperative(pybind11::module *m_ptr) { const platform::CPUPlace expected_place, const bool stop_gradient = false) { py::gil_scoped_release release; - return self.Trace(op, inputs, outputs, attrs_map, expected_place, - stop_gradient); + self.Trace(op, inputs, outputs, attrs_map, expected_place, + stop_gradient); }) .def("trace", [](imperative::Tracer &self, imperative::OpBase *op, const imperative::VarBasePtrMap &inputs, @@ -164,8 +168,8 @@ void BindImperative(pybind11::module *m_ptr) { const platform::CUDAPlace expected_place, const bool stop_gradient = false) { py::gil_scoped_release release; - return self.Trace(op, inputs, outputs, attrs_map, expected_place, - stop_gradient); + self.Trace(op, inputs, outputs, attrs_map, expected_place, + stop_gradient); }); // define parallel context diff --git a/python/paddle/fluid/dygraph/tracer.py b/python/paddle/fluid/dygraph/tracer.py index 9209245814..c802e31115 100644 --- a/python/paddle/fluid/dygraph/tracer.py +++ b/python/paddle/fluid/dygraph/tracer.py @@ -24,9 +24,7 @@ __all__ = ['Tracer'] def release_op(op): - del framework._dygraph_tracer()._ops[op._trace_id].inputs - del framework._dygraph_tracer()._ops[op._trace_id].outputs - del framework._dygraph_tracer()._ops[op._trace_id].backward_refs + del framework._dygraph_tracer()._ops[op._trace_id] class Tracer(core.Tracer): @@ -55,7 +53,6 @@ class Tracer(core.Tracer): def trace_op(self, op, inputs, outputs, stop_gradient=False): # TODO(hy): previous version will cause memory failed - op.inputs = inputs inps = defaultdict(list) for k, vars in six.iteritems(inputs): if isinstance(vars, framework.Variable): @@ -64,7 +61,6 @@ class Tracer(core.Tracer): for var in vars: inps[k].append(var._ivar) - op.outputs = 
outputs outs = defaultdict(list) for k, vars in six.iteritems(outputs): if isinstance(vars, framework.Variable): @@ -76,28 +72,15 @@ class Tracer(core.Tracer): # record op's trace id op.iop._trace_id = self._trace_id - backward_refs = self.trace(op.iop, inps, outs, op.attrs, - framework._current_expected_place(), - stop_gradient) + self.trace(op.iop, inps, outs, op.attrs, + framework._current_expected_place(), stop_gradient) if not stop_gradient and self._train_mode: self._trace_id += 1 self._ops[op.iop._trace_id] = op # register backward hooks and variables if needed - if len(backward_refs) > 0: - op.iop.register_backward_hooks(release_op) - - # TODO(minqiyang): remove all inputs and outputs after separate - # var and grad - op.backward_refs = defaultdict(list) - for k, v in six.iteritems(inputs): - if k in backward_refs: - op.backward_refs[k] = inputs[k] - - for k, v in six.iteritems(outputs): - if k in backward_refs: - op.backward_refs[k] = outputs[k] + op.iop.register_backward_hooks(release_op) def train_mode(self): self._train_mode = True diff --git a/python/paddle/fluid/tests/unittests/test_imperative_recurrent_usage.py b/python/paddle/fluid/tests/unittests/test_imperative_recurrent_usage.py new file mode 100644 index 0000000000..650c2482f8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_recurrent_usage.py @@ -0,0 +1,87 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.dygraph.nn import Embedding +import paddle.fluid.framework as framework +from paddle.fluid.optimizer import SGDOptimizer +from paddle.fluid.dygraph.base import to_variable +from test_imperative_base import new_program_scope +import numpy as np +import six + + +class RecurrentTest(fluid.Layer): + def __init__(self, name_scope): + super(RecurrentTest, self).__init__(name_scope) + + def forward(self, in1, in2): + out = fluid.layers.mul(in1, in2) + sum_out = fluid.layers.reduce_sum(out) + return sum_out, out + + +class TestRecurrentFeed(unittest.TestCase): + def test_recurrent_feed(self): + + seed = 90 + original_np1 = np.arange(1, 5).reshape(2, 2).astype("float32") + original_np2 = np.arange(5, 9).reshape(2, 2).astype("float32") + with fluid.dygraph.guard(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + original_in1 = to_variable(original_np1) + original_in2 = to_variable(original_np2) + rt = RecurrentTest("RecurrentTest") + + for i in range(3): + sum_out, out = rt(original_in1, original_in2) + original_in1 = out + sum_out_value = sum_out.numpy() + sum_out.backward() + rt.clear_gradients() + + with new_program_scope(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + in1 = fluid.layers.data( + name="inp1", shape=[2, 2], append_batch_size=False) + in2 = fluid.layers.data( + name="inp2", shape=[2, 2], append_batch_size=False) + rt1 = RecurrentTest("RecurrentTest") + static_sum_out, static_out = rt1(in1, in2) + fluid.backward.append_backward(static_sum_out) + exe = fluid.Executor(fluid.CPUPlace( + ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0)) + + fetch_list = [static_sum_out, static_out] + for i in range(3): + out = exe.run( + fluid.default_main_program(), + feed={"inp1": original_np1, + "inp2": original_np2}, + fetch_list=fetch_list) + static_out_value = out[1] + static_sum_out = out[0] + original_np1 = static_out_value + + self.assertTrue(np.array_equal(static_sum_out, sum_out_value)) + + +if __name__ == '__main__': + unittest.main() -- GitLab
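
Not part of the patch itself, but as a reading aid: a minimal, self-contained C++ sketch of the ownership scheme this change moves to. Forward outputs are owned by std::shared_ptr, the op only tracks them through std::weak_ptr (the outputs_ref map added to OpBase), and the op's destructor detaches any still-alive outputs via ResetPreOp instead of deleting raw pointers. Var and Op below are illustrative stand-ins under those assumptions, not the actual VarBase/OpBase classes.

    // Illustrative only: 'Var' and 'Op' are placeholders for Paddle's
    // VarBase/OpBase in paddle/fluid/imperative/layer.h.
    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    struct Op;  // forward declaration

    struct Var {
      explicit Var(std::string n) : name(std::move(n)) {}
      std::string name;
      Op* pre_op = nullptr;  // raw back-pointer; the op is owned elsewhere

      void ResetPreOp(Op* op) {
        if (pre_op == op) pre_op = nullptr;
      }
    };

    struct Op {
      // Track forward outputs with weak_ptr so that keeping the op alive
      // never keeps an otherwise-dead Var alive (mirrors `outputs_ref`).
      std::vector<std::weak_ptr<Var>> outputs_ref;

      ~Op() {
        // Detach outputs that are still alive, the way OpBase::~OpBase
        // calls vb->ResetPreOp(this) after weak_ptr::lock().
        for (auto& ref : outputs_ref) {
          if (auto var = ref.lock()) {
            std::cout << "op reset by " << var->name << "\n";
            var->ResetPreOp(this);
          }
        }
      }
    };

    int main() {
      auto op = std::make_unique<Op>();
      auto out = std::make_shared<Var>("out@0");
      out->pre_op = op.get();
      op->outputs_ref.emplace_back(out);  // weak: does not bump the use count

      {
        // An output that dies before the op; its weak_ptr simply fails to
        // lock in ~Op and is skipped, with no manual delete needed.
        auto tmp = std::make_shared<Var>("tmp@1");
        tmp->pre_op = op.get();
        op->outputs_ref.emplace_back(tmp);
      }

      op.reset();  // 'out' is detached, 'tmp' is already gone
      std::cout << "out.pre_op is " << (out->pre_op ? "set" : "null") << "\n";
      return 0;
    }

This is also why the recurrent-usage test added above works: reusing a forward output as the next iteration's input no longer leaves dangling raw pointers once earlier ops are released, because lifetime is shared_ptr-managed and back-references are weak.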