Unverified · Commit e77c062e authored by Haohongxiang, committed by GitHub

[Dygraph] Finish fixing mem bugs of no sync in DataParallel (#47444)

Parent 315ef265
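Context for the change below: this commit cleans up the no-sync path of paddle.DataParallel. When gradient synchronization is disabled, the gradient hooks now return early and share each gradient straight into its group buffer instead of scheduling an allreduce. A minimal usage sketch of that path (assumes a multi-GPU launch via `python -m paddle.distributed.launch`; the toy model and optimizer are illustrative, not part of the commit):

import paddle
import paddle.distributed as dist
from paddle import nn

dist.init_parallel_env()
model = paddle.DataParallel(nn.Linear(8, 8))
opt = paddle.optimizer.SGD(parameters=model.parameters())

x = paddle.randn([4, 8])
with model.no_sync():            # gradient synchronization is skipped here
    model(x).mean().backward()   # grads accumulate locally on each rank
model(x).mean().backward()       # this step all-reduces the accumulated grads
opt.step()
opt.clear_grad()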
@@ -588,10 +588,9 @@ void EagerReducer::TraverseBackwardGraph(const std::vector<Tensor> &outputs) {
   }
 }
 
-void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs,
-                                      const bool is_sync) {
+void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
   VLOG(3) << "after forward, then reset count for backward.";
-  grad_need_hooks_ = is_sync;
+  grad_need_hooks_ = true;
   next_group_ = 0;
   std::for_each(groups_.begin(), groups_.end(), [](EagerGroup &group) {
@@ -660,9 +659,25 @@ void EagerReducer::AddDistHook(size_t var_index) {
                         var_index));
 
   // gradient synchronization is not required when grad_need_hooks_ is false.
-  // if (!grad_need_hooks_) {
-  //   return;
-  // }
+  if (!grad_need_hooks_) {
+    const auto &var_locator = variable_locators_[var_index];
+    const auto group_index = var_locator.group_index;
+    const auto inside_group_index = var_locator.inside_group_index;
+    auto &group = groups_[group_index];
+    auto &group_tensor = group.dense_tensors_[inside_group_index];
+
+    auto *autograd_meta = tensors_[var_index].get_autograd_meta();
+    auto &grad_tensor = static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
+
+    if (!HasGrad(var_index)) {
+      group_tensor.ShareDataWith(phi::DenseTensor());
+    } else {
+      auto grad_dense_tensor =
+          *(std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl()));
+      group_tensor.ShareDataWith(grad_dense_tensor);
+    }
+    return;
+  }
 
   VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name()
           << "@Grad] arrived and triggered disthook";
@@ -828,12 +843,10 @@ void EagerReducer::MarkGroupReady(size_t group_index) {
   for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0;
        ++next_group_) {
     UNUSED auto &group = groups_[next_group_];
-    if (grad_need_hooks_) {
-      if (group.is_sparse_) {
-        AllReduceSparse(&group, next_group_);
-      } else {
-        FusedAllReduceSchedule(&group, next_group_);
-      }
+    if (group.is_sparse_) {
+      AllReduceSparse(&group, next_group_);
+    } else {
+      FusedAllReduceSchedule(&group, next_group_);
     }
   }
 }
@@ -921,14 +934,15 @@ void EagerReducer::ProcessUnusedDenseVars() {
 void EagerReducer::FinalizeBackward() {
   groups_need_finalize_ = false;
+  grad_need_hooks_ = false;
   for (auto &group : groups_) {
-    if (!group.is_sparse_ && grad_need_hooks_) {
+    if (!group.is_sparse_) {
       group.task->Synchronize();
     }
   }
 
   for (auto &group : groups_) {
-    if (!group.is_sparse_ && grad_need_hooks_) {
+    if (!group.is_sparse_) {
       group.dense_contents_.reset();
     }
   }
@@ -940,7 +954,6 @@ void EagerReducer::FinalizeBackward() {
     VLOG(3) << "ProcessUnusedDenseVars is finished.";
   }
 
-  grad_need_hooks_ = false;
   VLOG(3) << "In the batch, Reducer is finished.";
 }
@@ -103,8 +103,7 @@ class EagerReducer {
   void InitializeGroups(const std::vector<std::vector<size_t>> &group_indices);
   void InitializeDenseGroups(const std::vector<size_t> &tensor_indices_,
                              EagerGroup *p_group);
-  void PrepareForBackward(const std::vector<Tensor> &outputs,
-                          const bool is_sync);
+  void PrepareForBackward(const std::vector<Tensor> &outputs);
   void AddDistHook(size_t var_index);
   void MarkVarReady(const size_t var_index, const bool is_used_var);
   void MarkGroupReady(const size_t group_index);
@@ -675,10 +675,9 @@ void Reducer::TraverseBackwardGraph(
 // After each batch is calculated, the counter of each group(group.pending_)
 // and allreudce sequence counter(next_group_) will be cleaned up again.
 void Reducer::PrepareForBackward(
-    const std::vector<std::shared_ptr<imperative::VarBase>> &outputs,
-    const bool is_sync) {
+    const std::vector<std::shared_ptr<imperative::VarBase>> &outputs) {
   VLOG(3) << "after forward, then reset count for backward.";
-  grad_need_hooks_ = is_sync;
+  grad_need_hooks_ = true;
   next_group_ = 0;
   std::for_each(groups_.begin(), groups_.end(), [](Group &group) {
     group.pending_ = group.variable_indices_.size();
@@ -711,9 +710,7 @@ void Reducer::PrepareForBackward(
   if (find_unused_vars_once_ || find_unused_vars_each_step_) {
     unused_vars_.clear();
-    if (grad_need_hooks_) {
-      TraverseBackwardGraph(outputs);
-    }
+    TraverseBackwardGraph(outputs);
     // only check once in first step
     find_unused_vars_once_ = false;
   }
@@ -146,8 +146,7 @@ class Reducer {
   void PrepareDeps(const std::unordered_set<GradOpNode*>& init_nodes);
 
   void PrepareForBackward(
-      const std::vector<std::shared_ptr<imperative::VarBase>>& outputs,
-      const bool is_sync);
+      const std::vector<std::shared_ptr<imperative::VarBase>>& outputs);
 
   void AddDistHook(size_t var_index);
@@ -1407,14 +1407,11 @@ void BindDistributed(py::module *m) {
           .def(py::init(&CreateEagerReducer))
          .def(
              "prepare_for_backward",
-              [](distributed::EagerReducer &self,
-                 py::handle py_tensors,
-                 bool is_sync) {
+              [](distributed::EagerReducer &self, py::handle py_tensors) {
                auto params = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0);
-                self.PrepareForBackward(params, is_sync);
+                self.PrepareForBackward(params);
              },
              py::arg("tensors"),
-              py::arg("is_sync"),
              py::call_guard<py::gil_scoped_release>());
 }
@@ -2569,7 +2569,6 @@ void BindImperative(py::module *m_ptr) {
       .def("prepare_for_backward",
           &imperative::Reducer::PrepareForBackward,
           py::arg("vars"),
-           py::arg("is_sync"),
           py::call_guard<py::gil_scoped_release>());
 
   m.def("assign_group_by_size",
@@ -818,9 +818,13 @@ class DataParallel(layers.Layer):
     def forward(self, *inputs, **kwargs):
         outputs = self._layers(*inputs, **kwargs)
-        if self._strategy.nranks > 1 and framework._dygraph_tracer()._has_grad:
+        if (
+            self._strategy.nranks > 1
+            and framework._dygraph_tracer()._has_grad
+            and self.grad_need_sync
+        ):
             self._reducer.prepare_for_backward(
-                list(self._find_varbase(outputs)), self.grad_need_sync
+                list(self._find_varbase(outputs))
             )
         return outputs
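The self.grad_need_sync flag checked in the rewritten condition is the one flipped by DataParallel.no_sync(). For readers following the control flow, a simplified sketch of that context manager (an approximation for illustration, not the verbatim Paddle source):

from contextlib import contextmanager

@contextmanager
def no_sync(self):
    # Presumed behavior: disable gradient synchronization inside the block,
    # then restore the previous flag on exit.
    tmp_grad_need_sync = self.grad_need_sync
    self.grad_need_sync = False
    try:
        yield
    finally:
        self.grad_need_sync = tmp_grad_need_sync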