Unverified commit 3d3f5506, authored by Jiabin Yang, committed by GitHub

Feature/Fix recurrent usage of Varbase in Dygraph (#17838)

* for debug

* test=develop, memory optimize for dygraph using shared_ptr

* test=develop, fix travis ci showed error

* test=develop, fix bug for recurrent usage of varbase

* test=develop, init varbase when it need to be Add

* test=develop, fix problem of recurrent gradient

* test=develop, add gradient test for recurrent varbase usage
Parent 60094207
@@ -79,12 +79,16 @@ class TensorAddToFunctor : public boost::static_visitor<> {
 }  // namespace detail
 
 void AddTo(std::shared_ptr<VarBase> src, std::shared_ptr<VarBase> dst,
-           platform::Place place) {
-  if (!dst->IsInitialize()) {
-    VLOG(2) << "im here1";
+           platform::Place place, GradientRef* grad_ref) {
+  PADDLE_ENFORCE(grad_ref->find(dst.get()) != grad_ref->end(),
+                 "gradient %s are not found in grad_ref", dst->Name());
+  if ((*grad_ref)[dst.get()].second) {
     PADDLE_ENFORCE(src->IsInitialize(), "Using uninitialized VarBase");
     dst->var_ = std::move(src->var_);
-    dst->SetInitialize(true);
+    (*grad_ref)[dst.get()].second = false;
+    if (!dst->IsInitialize()) {
+      dst->SetInitialize(true);
+    }
     return;
   } else {
     framework::Tensor* dst_tensor =
@@ -118,7 +122,8 @@ void ZeroGrads(const std::shared_ptr<imperative::VarBase> vb,
 }
 
 void AddGradBySort(BackwardSumMap* bck_map,
-                   std::shared_ptr<imperative::VarBase> target) {
+                   std::shared_ptr<imperative::VarBase> target,
+                   GradientRef* grad_ref) {
   PADDLE_ENFORCE(bck_map->find(target.get()) != bck_map->end(),
                  "Can't find %s in backward grad map", target->Name());
   std::pair<platform::Place,
@@ -133,7 +138,7 @@ void AddGradBySort(BackwardSumMap* bck_map,
     VLOG(10) << "add origin_grad: " << target->Name();
     VLOG(10) << "added grad: " << var_pair.second->Name()
              << " trace id is: " << var_pair.first;
-    AddTo(var_pair.second, target, current.first);
+    AddTo(var_pair.second, target, current.first, grad_ref);
     var_pair.second.reset();
   }
 }
@@ -148,7 +153,6 @@ class Autograd {
     }
 
     VLOG(2) << "start autograd";
     BackwardSumMap bck_map;
-    GradientRef grad_ref;
     std::deque<OpBase*> ready;
     ready.push_back(var->PreOp());
@@ -200,12 +204,14 @@ class Autograd {
     while (!queue.empty()) {
       OpBase* candidate = queue.front();
       queue.pop_front();
-      if (bck_stratedy.sorted_sum_gradient_) {
-        for (const auto& map : candidate->grad_output_vars_) {
-          for (const auto& it : map) {
-            for (const auto& vb : it.second) {
-              ++(*grad_ref)[vb.get()];
-            }
+      for (const auto& map : candidate->grad_output_vars_) {
+        for (const auto& it : map) {
+          for (const auto& vb : it.second) {
+            if (bck_stratedy.sorted_sum_gradient_) {
+              ++(*grad_ref)[vb.get()].first;
+            }
+            // init the state of the grad_
+            (*grad_ref)[vb.get()].second = true;
           }
         }
       }
@@ -225,6 +231,8 @@ class Autograd {
     }
     return ret;
   }
+
+  GradientRef grad_ref;
 };
 
 std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
@@ -382,21 +390,21 @@ std::vector<VarBasePtrMap> OpBase::ApplyGrad(
               grad_ref->find(origin_outputs[i].get()) != grad_ref->end(),
               "Can't find %s in grad_reference count map",
               origin_outputs[i]->Name());
-          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()) >= 1,
+          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()).first >= 1,
                          "Backward error when calculate grad reference");
-          if (grad_ref->at(origin_outputs[i].get()) > 1) {
+          if (grad_ref->at(origin_outputs[i].get()).first > 1) {
             VLOG(10) << "remove ref for " << origin_outputs[i]->Name();
-            grad_ref->at(origin_outputs[i].get())--;
+            grad_ref->at(origin_outputs[i].get()).first--;
           } else {
             VLOG(10) << "Add grad for: " << origin_outputs[i]->Name();
-            AddGradBySort(bck_map, origin_outputs[i]);
-            grad_ref->at(origin_outputs[i].get())--;
+            AddGradBySort(bck_map, origin_outputs[i], grad_ref);
+            grad_ref->at(origin_outputs[i].get()).first--;
           }
         } else {
           VLOG(10) << "AddTo Called with orig_grad is: "
                    << origin_outputs[i]->name_ << " Grad to be added is "
                    << outputs[i]->name_;
-          AddTo(outputs[i], origin_outputs[i], place_);
+          AddTo(outputs[i], origin_outputs[i], place_, grad_ref);
           outputs[i].reset();
         }
       }
......
@@ -166,6 +166,7 @@ class VarBase {
     if (!var_) {
       var_.reset(new framework::Variable());
     }
+
     auto tensor = var_->GetMutable<framework::LoDTensor>();
     tensor->Resize(shape);
     if (need_initialize) {
@@ -310,13 +311,11 @@ class PYBIND11_HIDDEN OpBase {
         backward_hooks_() {}
 
   virtual ~OpBase() {
-    for (const auto& iter : outputs_ref) {
-      for (const auto& var : iter.second) {
-        auto vb = var.lock();
-        if (vb) {
-          VLOG(3) << "Op reset by" << vb->name_;
-          vb->ResetPreOp(this);
-        }
+    for (const auto& it : outputs_ref) {
+      auto vb = it.lock();
+      if (vb) {
+        VLOG(3) << "Op reset by" << vb->name_;
+        vb->ResetPreOp(this);
       }
     }
     // TODO(minqiyang): remove op_desc from block_desc in tracer
@@ -372,7 +371,7 @@ class PYBIND11_HIDDEN OpBase {
 
   OpBasePtrMap pre_ops_;
   std::map<std::string, std::vector<int>> pre_ops_out_idx_;
-  VarBaseWeakPtrMap outputs_ref;
+  VarBaseWeakPtrList outputs_ref;
   // Inputs to a vector of bwd ops.
   std::vector<VarBasePtrMap> grad_input_vars_;
   // Outputs to a vector of bwd ops.
......
@@ -172,7 +172,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       outvars.reserve(outputs_tmp.size());
       for (size_t i = 0U; i < outputs_tmp.size(); ++i) {
         // Add weak_ptr to track outputs
-        op->outputs_ref[it.first].emplace_back(outputs_tmp[i]);
+        op->outputs_ref.emplace_back(outputs_tmp[i]);
         std::shared_ptr<imperative::VarBase> out = outputs_tmp[i];
         outvars.emplace_back(out->var_.get());
         out->TrackPreOp(op, it.first, i, stop_gradient);
......
@@ -29,17 +29,15 @@ class OpBase;
 
 typedef std::map<std::string, std::vector<std::shared_ptr<VarBase>>>
     VarBasePtrMap;
-typedef std::map<std::string, std::vector<std::weak_ptr<VarBase>>>
-    VarBaseWeakPtrMap;
-typedef std::map<std::string, std::vector<const std::shared_ptr<VarBase>>>
-    ConstVarBasePtrMap;
+typedef std::vector<std::weak_ptr<VarBase>> VarBaseWeakPtrList;
 typedef std::map<std::string, std::vector<OpBase*>> OpBasePtrMap;
 typedef std::unordered_map<
     const VarBase*,
     std::pair<platform::Place,
              std::vector<std::pair<int, std::shared_ptr<VarBase>>>>>
     BackwardSumMap;  // var_grad -> {place, {id -> var_grad@rename}}
-typedef std::unordered_map<const VarBase*, int> GradientRef;
+typedef std::unordered_map<const VarBase*, std::pair<int, bool>> GradientRef;
+// var_grad -> {ref_times, is_first_to_be_accumulate}
 
 }  // namespace imperative
 }  // namespace paddle
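The GradientRef entry defined above pairs a reference count with a "first accumulation" flag: the count records how many backward contributions a gradient still expects, and the flag lets the first contribution simply take over the source tensor (as the new AddTo does) while later contributions are summed. The standalone sketch below illustrates that bookkeeping with simplified stand-in types (a Var holding a single float rather than Paddle's VarBase/LoDTensor); it is an illustration of the idea, not Paddle code.

// Standalone sketch of the {ref_times, is_first_to_be_accumulated} bookkeeping.
// Var and its float payload are simplified stand-ins, not Paddle classes.
#include <cassert>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <utility>

struct Var {
  std::shared_ptr<float> value;  // plays the role of VarBase::var_
};

// Mirrors: typedef std::unordered_map<const VarBase*, std::pair<int, bool>> GradientRef;
using GradientRef = std::unordered_map<const Var*, std::pair<int, bool>>;

// Mirrors AddTo: the first accumulation moves the source payload into dst,
// later accumulations add to it (here: a single float).
void AddTo(const std::shared_ptr<Var>& src, const std::shared_ptr<Var>& dst,
           GradientRef* grad_ref) {
  auto it = grad_ref->find(dst.get());
  assert(it != grad_ref->end() && "gradient not found in grad_ref");
  if (it->second.second) {   // first contribution: take ownership of src
    dst->value = std::move(src->value);
    it->second.second = false;
  } else {                   // later contributions: accumulate
    *dst->value += *src->value;
  }
}

int main() {
  auto grad = std::make_shared<Var>();  // gradient of a VarBase reused twice
  GradientRef grad_ref;
  grad_ref[grad.get()] = {2, true};     // referenced by two backward ops

  auto g0 = std::make_shared<Var>(Var{std::make_shared<float>(1.5f)});
  auto g1 = std::make_shared<Var>(Var{std::make_shared<float>(2.0f)});

  AddTo(g0, grad, &grad_ref);           // moves 1.5 into grad
  --grad_ref[grad.get()].first;
  AddTo(g1, grad, &grad_ref);           // adds 2.0 -> 3.5
  --grad_ref[grad.get()].first;

  std::cout << *grad->value << std::endl;  // prints 3.5
  return 0;
}

Running the sketch prints 3.5, i.e. both contributions end up accumulated into the single reused gradient, which is exactly the recurrent-usage case this commit fixes.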
@@ -54,6 +54,7 @@ class TestRecurrentFeed(unittest.TestCase):
                 original_in1 = out
                 sum_out_value = sum_out.numpy()
                 sum_out.backward()
+                dyout = out.gradient()
                 rt.clear_gradients()
 
         with new_program_scope():
...@@ -69,7 +70,9 @@ class TestRecurrentFeed(unittest.TestCase): ...@@ -69,7 +70,9 @@ class TestRecurrentFeed(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace( exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0)) ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
fetch_list = [static_sum_out, static_out] static_dout = fluid.default_main_program().block(
0)._find_var_recursive(static_out.name + "@GRAD")
fetch_list = [static_sum_out, static_out, static_dout]
for i in range(3): for i in range(3):
out = exe.run( out = exe.run(
fluid.default_main_program(), fluid.default_main_program(),
@@ -78,9 +81,11 @@ class TestRecurrentFeed(unittest.TestCase):
                     fetch_list=fetch_list)
                 static_out_value = out[1]
                 static_sum_out = out[0]
+                static_dout = out[2]
                 original_np1 = static_out_value
 
         self.assertTrue(np.array_equal(static_sum_out, sum_out_value))
+        self.assertTrue(np.array_equal(static_dout, dyout))
 
 
 if __name__ == '__main__':
......