From 3d3f5506d261fe31b8217b5e5d59d14b48346216 Mon Sep 17 00:00:00 2001
From: Jiabin Yang
Date: Wed, 5 Jun 2019 12:50:41 +0800
Subject: [PATCH] Feature/Fix recurrent usage of Varbase in Dygraph (#17838)

* for debug

* test=develop, memory optimize for dygraph using shared_ptr

* test=develop, fix travis ci showed error

* test=develop, fix bug for recurrent usage of varbase

* test=develop, init varbase when it need to be Add

* test=develop, fix problem of recurrent gradient

* test=develop, add gradient test for recurrent varbase usage
---
 paddle/fluid/imperative/layer.cc              | 44 +++++++++++--------
 paddle/fluid/imperative/layer.h               | 15 +++----
 paddle/fluid/imperative/tracer.cc             |  2 +-
 paddle/fluid/imperative/type_defs.h           |  8 ++--
 .../test_imperative_recurrent_usage.py        |  7 ++-
 5 files changed, 43 insertions(+), 33 deletions(-)

diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index 4bced3a0e83..27463c0470a 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -79,12 +79,16 @@ class TensorAddToFunctor : public boost::static_visitor<> {
 }  // namespace detail
 
 void AddTo(std::shared_ptr<VarBase> src, std::shared_ptr<VarBase> dst,
-           platform::Place place) {
-  if (!dst->IsInitialize()) {
-    VLOG(2) << "im here1";
+           platform::Place place, GradientRef* grad_ref) {
+  PADDLE_ENFORCE(grad_ref->find(dst.get()) != grad_ref->end(),
+                 "gradient %s are not found in grad_ref", dst->Name());
+  if ((*grad_ref)[dst.get()].second) {
     PADDLE_ENFORCE(src->IsInitialize(), "Using uninitialized VarBase");
     dst->var_ = std::move(src->var_);
-    dst->SetInitialize(true);
+    (*grad_ref)[dst.get()].second = false;
+    if (!dst->IsInitialize()) {
+      dst->SetInitialize(true);
+    }
     return;
   } else {
     framework::Tensor* dst_tensor =
@@ -118,7 +122,8 @@ void ZeroGrads(const std::shared_ptr<VarBase> vb,
 }
 
 void AddGradBySort(BackwardSumMap* bck_map,
-                   std::shared_ptr<VarBase> target) {
+                   std::shared_ptr<VarBase> target,
+                   GradientRef* grad_ref) {
   PADDLE_ENFORCE(bck_map->find(target.get()) != bck_map->end(),
                  "Can't find %s in backward grad map", target->Name());
   std::pair<platform::Place,
             std::vector<std::pair<int, std::shared_ptr<VarBase>>>>& current =
       bck_map->at(target.get());
   std::sort(current.second.begin(), current.second.end(),
             [](const std::pair<int, std::shared_ptr<VarBase>>& a,
                const std::pair<int, std::shared_ptr<VarBase>>& b) {
               return a.first > b.first;
             });
   for (auto& var_pair : current.second) {
     VLOG(10) << "add origin_grad: " << target->Name();
     VLOG(10) << "added grad: " << var_pair.second->Name()
              << " trace id is: " << var_pair.first;
-    AddTo(var_pair.second, target, current.first);
+    AddTo(var_pair.second, target, current.first, grad_ref);
     var_pair.second.reset();
   }
 }
@@ -148,7 +153,6 @@ class Autograd {
     }
     VLOG(2) << "start autograd";
     BackwardSumMap bck_map;
-    GradientRef grad_ref;
     std::deque<OpBase*> ready;
 
     ready.push_back(var->PreOp());
@@ -200,12 +204,14 @@ class Autograd {
     while (!queue.empty()) {
       OpBase* candidate = queue.front();
       queue.pop_front();
-      if (bck_stratedy.sorted_sum_gradient_) {
-        for (const auto& map : candidate->grad_output_vars_) {
-          for (const auto& it : map) {
-            for (const auto& vb : it.second) {
-              ++(*grad_ref)[vb.get()];
+      for (const auto& map : candidate->grad_output_vars_) {
+        for (const auto& it : map) {
+          for (const auto& vb : it.second) {
+            if (bck_stratedy.sorted_sum_gradient_) {
+              ++(*grad_ref)[vb.get()].first;
             }
+            // init the state of the grad_
+            (*grad_ref)[vb.get()].second = true;
           }
         }
       }
@@ -225,6 +231,8 @@ class Autograd {
     }
     return ret;
   }
+
+  GradientRef grad_ref;
 };
 
 std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
@@ -382,21 +390,21 @@ std::vector<VarBasePtrMap> OpBase::ApplyGrad(
           grad_ref->find(origin_outputs[i].get()) != grad_ref->end(),
           "Can't find %s in grad_reference count map",
           origin_outputs[i]->Name());
-      PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()) >= 1,
+      PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()).first >= 1,
                      "Backward error when calculate grad reference");
-      if (grad_ref->at(origin_outputs[i].get()) > 1) {
+      if (grad_ref->at(origin_outputs[i].get()).first > 1) {
         VLOG(10) << "remove ref for " << origin_outputs[i]->Name();
-        grad_ref->at(origin_outputs[i].get())--;
+        grad_ref->at(origin_outputs[i].get()).first--;
       } else {
         VLOG(10) << "Add grad for: " << origin_outputs[i]->Name();
-        AddGradBySort(bck_map, origin_outputs[i]);
-        grad_ref->at(origin_outputs[i].get())--;
+        AddGradBySort(bck_map, origin_outputs[i], grad_ref);
+        grad_ref->at(origin_outputs[i].get()).first--;
       }
     } else {
       VLOG(10) << "AddTo Called with orig_grad is: "
               << origin_outputs[i]->name_ << " Grad to be added is "
               << outputs[i]->name_;
-      AddTo(outputs[i], origin_outputs[i], place_);
+      AddTo(outputs[i], origin_outputs[i], place_, grad_ref);
       outputs[i].reset();
     }
   }
diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h
index 3d31001df9e..d0d02f0f424 100644
--- a/paddle/fluid/imperative/layer.h
+++ b/paddle/fluid/imperative/layer.h
@@ -166,6 +166,7 @@ class VarBase {
     if (!var_) {
       var_.reset(new framework::Variable());
     }
+
     auto tensor = var_->GetMutable<framework::LoDTensor>();
     tensor->Resize(shape);
     if (need_initialize) {
@@ -310,13 +311,11 @@ class PYBIND11_HIDDEN OpBase {
         backward_hooks_() {}
 
   virtual ~OpBase() {
-    for (const auto& iter : outputs_ref) {
-      for (const auto& var : iter.second) {
-        auto vb = var.lock();
-        if (vb) {
-          VLOG(3) << "Op reset by" << vb->name_;
-          vb->ResetPreOp(this);
-        }
+    for (const auto& it : outputs_ref) {
+      auto vb = it.lock();
+      if (vb) {
+        VLOG(3) << "Op reset by" << vb->name_;
+        vb->ResetPreOp(this);
       }
     }
     // TODO(minqiyang): remove op_desc from block_desc in tracer
@@ -372,7 +371,7 @@ class PYBIND11_HIDDEN OpBase {
   OpBasePtrMap pre_ops_;
   std::map<std::string, std::vector<int>> pre_ops_out_idx_;
 
-  VarBaseWeakPtrMap outputs_ref;
+  VarBaseWeakPtrList outputs_ref;
   // Inputs to a vector of bwd ops.
   std::vector<VarBasePtrMap> grad_input_vars_;
   // Outputs to a vector of bwd ops.
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index bde5c6d4002..682bea7d09b 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -172,7 +172,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
     outvars.reserve(outputs_tmp.size());
     for (size_t i = 0U; i < outputs_tmp.size(); ++i) {
       // Add weak_ptr to track outputs
-      op->outputs_ref[it.first].emplace_back(outputs_tmp[i]);
+      op->outputs_ref.emplace_back(outputs_tmp[i]);
       std::shared_ptr<VarBase> out = outputs_tmp[i];
       outvars.emplace_back(out->var_.get());
       out->TrackPreOp(op, it.first, i, stop_gradient);
diff --git a/paddle/fluid/imperative/type_defs.h b/paddle/fluid/imperative/type_defs.h
index c22208a392b..fab8c2e6b91 100644
--- a/paddle/fluid/imperative/type_defs.h
+++ b/paddle/fluid/imperative/type_defs.h
@@ -29,17 +29,15 @@ class OpBase;
 
 typedef std::map<std::string, std::vector<std::shared_ptr<VarBase>>>
     VarBasePtrMap;
-typedef std::map<std::string, std::vector<std::weak_ptr<VarBase>>>
-    VarBaseWeakPtrMap;
-typedef std::map<std::string, std::vector<const std::shared_ptr<VarBase>>>
-    ConstVarBasePtrMap;
+typedef std::vector<std::weak_ptr<VarBase>> VarBaseWeakPtrList;
 typedef std::map<std::string, std::vector<OpBase*>> OpBasePtrMap;
 typedef std::unordered_map<
     const VarBase*,
     std::pair<platform::Place,
              std::vector<std::pair<int, std::shared_ptr<VarBase>>>>>
     BackwardSumMap;  // var_grad -> {place, {id -> var_grad@rename}}
-typedef std::unordered_map<const VarBase*, int> GradientRef;
+typedef std::unordered_map<const VarBase*, std::pair<int, bool>> GradientRef;
+// var_grad -> {ref_times, is_first_to_be_accumulate}
 
 }  // namespace imperative
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_recurrent_usage.py b/python/paddle/fluid/tests/unittests/test_imperative_recurrent_usage.py
index 650c2482f88..08c0fc8f001 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_recurrent_usage.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_recurrent_usage.py
@@ -54,6 +54,7 @@ class TestRecurrentFeed(unittest.TestCase):
                 original_in1 = out
                 sum_out_value = sum_out.numpy()
                 sum_out.backward()
+                dyout = out.gradient()
                 rt.clear_gradients()
 
         with new_program_scope():
@@ -69,7 +70,9 @@ class TestRecurrentFeed(unittest.TestCase):
 
             exe = fluid.Executor(fluid.CPUPlace(
             ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
-            fetch_list = [static_sum_out, static_out]
+            static_dout = fluid.default_main_program().block(
+                0)._find_var_recursive(static_out.name + "@GRAD")
+            fetch_list = [static_sum_out, static_out, static_dout]
             for i in range(3):
                 out = exe.run(
                     fluid.default_main_program(),
@@ -78,9 +81,11 @@ class TestRecurrentFeed(unittest.TestCase):
                     fetch_list=fetch_list)
                 static_out_value = out[1]
                 static_sum_out = out[0]
+                static_dout = out[2]
                 original_np1 = static_out_value
 
         self.assertTrue(np.array_equal(static_sum_out, sum_out_value))
+        self.assertTrue(np.array_equal(static_dout, dyout))
 
 
 if __name__ == '__main__':
-- 
GitLab
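Note (not part of the patch): the usage pattern this change targets is the one exercised by
test_imperative_recurrent_usage.py above: an output VarBase is fed back into the same layer
across several iterations, then backward() runs once, so the gradient of that reused VarBase
has to be accumulated correctly. The sketch below is a minimal illustration against the
Paddle 1.x dygraph API used by the test; the RecurrentTest layer body here is an assumed
stand-in, not copied from the test file.

    import numpy as np
    import paddle.fluid as fluid


    class RecurrentTest(fluid.dygraph.Layer):
        # Assumed stand-in for the layer in the test; only the recurrent
        # feeding pattern matters for this patch.
        def __init__(self, name_scope):
            super(RecurrentTest, self).__init__(name_scope)

        def forward(self, in1, in2):
            out = fluid.layers.mul(in1, in2)
            sum_out = fluid.layers.reduce_sum(out)
            return sum_out, out


    with fluid.dygraph.guard():
        in1 = fluid.dygraph.to_variable(np.ones([3, 3], dtype="float32"))
        in2 = fluid.dygraph.to_variable(np.ones([3, 3], dtype="float32"))
        rt = RecurrentTest("RecurrentTest")
        for _ in range(3):
            sum_out, out = rt(in1, in2)
            # Recurrent usage: the previous output VarBase becomes the next input.
            in1 = out
        sum_out.backward()
        # After this fix, the gradient of the recurrently used VarBase should match
        # the value produced by the equivalent static-graph program.
        dyout = out.gradient()
        rt.clear_gradients()

The static half of the test builds the same three-step loop as a static program and fetches
static_out.name + "@GRAD" so the two gradients can be compared, which is what the added
assertTrue in the last hunk checks.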