Unverified commit 196b0187, authored by sneaxiy, committed by GitHub

Lazy initialize dense_contents_ in reducer (#45631)

* make dense_contents_ lazy init

* update legacy dygraph

* fix legacy dygraph bug
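
The diff below moves the allocation of each group's fused `dense_contents_` buffer out of `InitializeGroups` and into `ConcatTensors`, and releases the buffer again in `FinalizeBackward`, so the fused gradient storage only exists while a group is actually being reduced. The following is a minimal standalone sketch of that lazy-allocation pattern, assuming a plain `std::vector<float>` as a stand-in for the real `dense_contents_` tensor; `FusedGroup` and its methods are illustrative and are not Paddle's actual API.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Illustrative stand-in for a reducer group; not the real EagerGroup class.
struct FusedGroup {
  std::size_t all_length_ = 0;  // total number of elements across the group
  std::unique_ptr<std::vector<float>> dense_contents_;  // fused gradient buffer

  // Before this change the buffer was allocated eagerly in InitializeGroups(),
  // so it stayed resident for the reducer's whole lifetime. Now it is created
  // only when the per-parameter gradients are actually concatenated.
  void ConcatTensors() {
    dense_contents_ = std::make_unique<std::vector<float>>(all_length_);
    // ... copy each dense gradient into dense_contents_ ...
  }

  // After the fused buffer has been split back into per-parameter gradients at
  // the end of backward, drop it so the memory is not held between iterations
  // (mirrors dense_contents_.reset() / dense_contents_.Clear() in the diff).
  void FinalizeBackward() {
    // ... split dense_contents_ back into the original gradients ...
    dense_contents_.reset();
  }
};

int main() {
  FusedGroup group;
  group.all_length_ = 1024;
  group.ConcatTensors();     // buffer allocated lazily here
  group.FinalizeBackward();  // buffer released here
  return 0;
}
```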
Parent 5235ec53
@@ -290,6 +290,9 @@ static void SplitTensorsWithType(const DeviceContext &context,
 }
 void EagerGroup::ConcatTensors(const platform::Place &place) {
+  dense_contents_ =
+      paddle::experimental::empty(IntArray({all_length_}), dtype_, place);
   if (platform::is_gpu_place(place)) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     auto *default_ctx = static_cast<phi::GPUContext *>(
@@ -452,8 +455,6 @@ void EagerReducer::InitializeGroups(
     } else {
       // process the dense gradient.
       InitializeDenseGroups(tensor_indices_, &group);
-      group.dense_contents_ = paddle::experimental::empty(
-          IntArray({group.all_length_}), group.dtype_, inner_place_);
     }
     // map tensors to this group by VariableLocator
@@ -908,6 +909,7 @@ void EagerReducer::FinalizeBackward() {
   for (auto &group : groups_) {
     if (!group.is_sparse_) {
       group.SplitTensors(inner_place_);
+      group.dense_contents_.reset();
     }
   }
@@ -543,9 +543,6 @@ void Reducer::InitializeGroups(
     } else {
       // process the dense gradient.
       InitializeDenseGroups(variable_indices_, &group);
-      auto tensor = group.dense_contents_.GetMutable<framework::LoDTensor>();
-      tensor->Resize(phi::make_ddim({group.all_length_}))
-          .mutable_data(place_, framework::TransToPhiDataType(group.dtype_));
     }
     // map variables to this group by VariableLocator
@@ -954,6 +951,10 @@ void Reducer::MarkGroupReady(size_t group_index) {
     UNUSED auto &group = groups_[next_group_];
     UNUSED const int run_order = next_group_ % nrings_;
+    auto *tensor = group.dense_contents_.GetMutable<framework::LoDTensor>();
+    tensor->Resize(phi::make_ddim({group.all_length_}))
+        .mutable_data(place_, framework::TransToPhiDataType(group.dtype_));
     // For CUDA or XPU, compute_stream --> comm_stream.
     // For CPU, do nothing.
     // NOTE. Because concat uses the comm_stream,
@@ -1116,6 +1117,12 @@ void Reducer::FinalizeBackward() {
     parallel_ctx_->WaitComm(i);
   }
+  for (auto &group : groups_) {
+    if (!group.is_sparse_) {
+      group.dense_contents_.Clear();
+    }
+  }
   if (NeedRebuildGroup()) {
     VLOG(3) << "Start rebuilding the groups";
     auto rebuild_group_indices = RebuildGruops();