未验证 提交 3d3f5506 编写于 作者: J Jiabin Yang 提交者: GitHub

Feature/Fix recurrent usage of Varbase in Dygraph (#17838)

* for debug

* test=develop, memory optimize for dygraph using shared_ptr

* test=develop, fix travis ci showed error

* test=develop, fix bug for recurrent usage of varbase

* test=develop, init varbase when it need to be Add

* test=develop, fix problem of recurrent gradient

* test=develop, add gradient test for recurrent varbase usage
上级 60094207
......@@ -79,12 +79,16 @@ class TensorAddToFunctor : public boost::static_visitor<> {
} // namespace detail
void AddTo(std::shared_ptr<VarBase> src, std::shared_ptr<VarBase> dst,
platform::Place place) {
if (!dst->IsInitialize()) {
VLOG(2) << "im here1";
platform::Place place, GradientRef* grad_ref) {
PADDLE_ENFORCE(grad_ref->find(dst.get()) != grad_ref->end(),
"gradient %s are not found in grad_ref", dst->Name());
if ((*grad_ref)[dst.get()].second) {
PADDLE_ENFORCE(src->IsInitialize(), "Using uninitialized VarBase");
dst->var_ = std::move(src->var_);
(*grad_ref)[dst.get()].second = false;
if (!dst->IsInitialize()) {
dst->SetInitialize(true);
}
return;
} else {
framework::Tensor* dst_tensor =
......@@ -118,7 +122,8 @@ void ZeroGrads(const std::shared_ptr<imperative::VarBase> vb,
}
void AddGradBySort(BackwardSumMap* bck_map,
std::shared_ptr<imperative::VarBase> target) {
std::shared_ptr<imperative::VarBase> target,
GradientRef* grad_ref) {
PADDLE_ENFORCE(bck_map->find(target.get()) != bck_map->end(),
"Can't find %s in backward grad map", target->Name());
std::pair<platform::Place,
......@@ -133,7 +138,7 @@ void AddGradBySort(BackwardSumMap* bck_map,
VLOG(10) << "add origin_grad: " << target->Name();
VLOG(10) << "added grad: " << var_pair.second->Name()
<< " trace id is: " << var_pair.first;
AddTo(var_pair.second, target, current.first);
AddTo(var_pair.second, target, current.first, grad_ref);
var_pair.second.reset();
}
}
......@@ -148,7 +153,6 @@ class Autograd {
}
VLOG(2) << "start autograd";
BackwardSumMap bck_map;
GradientRef grad_ref;
std::deque<OpBase*> ready;
ready.push_back(var->PreOp());
......@@ -200,12 +204,14 @@ class Autograd {
while (!queue.empty()) {
OpBase* candidate = queue.front();
queue.pop_front();
if (bck_stratedy.sorted_sum_gradient_) {
for (const auto& map : candidate->grad_output_vars_) {
for (const auto& it : map) {
for (const auto& vb : it.second) {
++(*grad_ref)[vb.get()];
if (bck_stratedy.sorted_sum_gradient_) {
++(*grad_ref)[vb.get()].first;
}
// init the state of the grad_
(*grad_ref)[vb.get()].second = true;
}
}
}
......@@ -225,6 +231,8 @@ class Autograd {
}
return ret;
}
GradientRef grad_ref;
};
std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
......@@ -382,21 +390,21 @@ std::vector<VarBasePtrMap> OpBase::ApplyGrad(
grad_ref->find(origin_outputs[i].get()) != grad_ref->end(),
"Can't find %s in grad_reference count map",
origin_outputs[i]->Name());
PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()) >= 1,
PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()).first >= 1,
"Backward error when calculate grad reference");
if (grad_ref->at(origin_outputs[i].get()) > 1) {
if (grad_ref->at(origin_outputs[i].get()).first > 1) {
VLOG(10) << "remove ref for " << origin_outputs[i]->Name();
grad_ref->at(origin_outputs[i].get())--;
grad_ref->at(origin_outputs[i].get()).first--;
} else {
VLOG(10) << "Add grad for: " << origin_outputs[i]->Name();
AddGradBySort(bck_map, origin_outputs[i]);
grad_ref->at(origin_outputs[i].get())--;
AddGradBySort(bck_map, origin_outputs[i], grad_ref);
grad_ref->at(origin_outputs[i].get()).first--;
}
} else {
VLOG(10) << "AddTo Called with orig_grad is: "
<< origin_outputs[i]->name_ << " Grad to be added is "
<< outputs[i]->name_;
AddTo(outputs[i], origin_outputs[i], place_);
AddTo(outputs[i], origin_outputs[i], place_, grad_ref);
outputs[i].reset();
}
}
......
......@@ -166,6 +166,7 @@ class VarBase {
if (!var_) {
var_.reset(new framework::Variable());
}
auto tensor = var_->GetMutable<framework::LoDTensor>();
tensor->Resize(shape);
if (need_initialize) {
......@@ -310,15 +311,13 @@ class PYBIND11_HIDDEN OpBase {
backward_hooks_() {}
virtual ~OpBase() {
for (const auto& iter : outputs_ref) {
for (const auto& var : iter.second) {
auto vb = var.lock();
for (const auto& it : outputs_ref) {
auto vb = it.lock();
if (vb) {
VLOG(3) << "Op reset by" << vb->name_;
vb->ResetPreOp(this);
}
}
}
// TODO(minqiyang): remove op_desc from block_desc in tracer
// release resource
for (framework::OpDesc* desc : grad_op_descs_) {
......@@ -372,7 +371,7 @@ class PYBIND11_HIDDEN OpBase {
OpBasePtrMap pre_ops_;
std::map<std::string, std::vector<int>> pre_ops_out_idx_;
VarBaseWeakPtrMap outputs_ref;
VarBaseWeakPtrList outputs_ref;
// Inputs to a vector of bwd ops.
std::vector<VarBasePtrMap> grad_input_vars_;
// Outputs to a vector of bwd ops.
......
......@@ -172,7 +172,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
outvars.reserve(outputs_tmp.size());
for (size_t i = 0U; i < outputs_tmp.size(); ++i) {
// Add weak_ptr to track outputs
op->outputs_ref[it.first].emplace_back(outputs_tmp[i]);
op->outputs_ref.emplace_back(outputs_tmp[i]);
std::shared_ptr<imperative::VarBase> out = outputs_tmp[i];
outvars.emplace_back(out->var_.get());
out->TrackPreOp(op, it.first, i, stop_gradient);
......
......@@ -29,17 +29,15 @@ class OpBase;
typedef std::map<std::string, std::vector<std::shared_ptr<VarBase>>>
VarBasePtrMap;
typedef std::map<std::string, std::vector<std::weak_ptr<VarBase>>>
VarBaseWeakPtrMap;
typedef std::map<std::string, std::vector<const std::shared_ptr<VarBase>>>
ConstVarBasePtrMap;
typedef std::vector<std::weak_ptr<VarBase>> VarBaseWeakPtrList;
typedef std::map<std::string, std::vector<OpBase*>> OpBasePtrMap;
typedef std::unordered_map<
const VarBase*,
std::pair<platform::Place,
std::vector<std::pair<int, std::shared_ptr<VarBase>>>>>
BackwardSumMap; // var_grad -> {place, {id -> var_grad@rename}}
typedef std::unordered_map<const VarBase*, int> GradientRef;
typedef std::unordered_map<const VarBase*, std::pair<int, bool>> GradientRef;
// var_grad -> {ref_times, is_first_to_be_accumulate}
} // namespace imperative
} // namespace paddle
......@@ -54,6 +54,7 @@ class TestRecurrentFeed(unittest.TestCase):
original_in1 = out
sum_out_value = sum_out.numpy()
sum_out.backward()
dyout = out.gradient()
rt.clear_gradients()
with new_program_scope():
......@@ -69,7 +70,9 @@ class TestRecurrentFeed(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
fetch_list = [static_sum_out, static_out]
static_dout = fluid.default_main_program().block(
0)._find_var_recursive(static_out.name + "@GRAD")
fetch_list = [static_sum_out, static_out, static_dout]
for i in range(3):
out = exe.run(
fluid.default_main_program(),
......@@ -78,9 +81,11 @@ class TestRecurrentFeed(unittest.TestCase):
fetch_list=fetch_list)
static_out_value = out[1]
static_sum_out = out[0]
static_dout = out[2]
original_np1 = static_out_value
self.assertTrue(np.array_equal(static_sum_out, sum_out_value))
self.assertTrue(np.array_equal(static_dout, dyout))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册