未验证 提交 0175f566 编写于 作者: S ShenLiang 提交者: GitHub

[cherry-pick] Solve inconsistent order in each card in dynamic graph (#30965)

* support if else control

* fix conflict
上级 aaaae6b4
......@@ -22,11 +22,6 @@ std::shared_ptr<Reducer> Reducer::s_instance_ = NULL;
// context is used to select the stream for concat
void Group::ConcatTensors(const platform::CUDADeviceContext &context) {
VLOG(3) << "Before concat, set output tensor size is " << all_length_;
auto tensor = dense_contents_.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({all_length_}))
.mutable_data(context.GetPlace(), dtype_);
switch (dtype_) {
case framework::proto::VarType::FP16:
ConcatTensorsForAllReduce<platform::float16>(context, dense_tensors_,
......@@ -179,6 +174,9 @@ void Reducer::InitializeDenseGroups(
p_group->length_.push_back(size);
// for concat operator
p_group->dense_tensors_.push_back(framework::Tensor());
// check the dtype and place, it must be same.
auto dtype = var->DataType();
auto place = var->Place();
......@@ -200,6 +198,7 @@ void Reducer::InitializeDenseGroups(
place_ = place;
}
}
p_group->all_length_ = all_length;
}
// Each parameter will be initialized according to the group information.
......@@ -234,6 +233,9 @@ void Reducer::InitializeGroups(
} else {
// process the dense gradient.
InitializeDenseGroups(variable_indices_, &group);
auto tensor = group.dense_contents_.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({group.all_length_}))
.mutable_data(place_, group.dtype_);
}
// map variables to this group by VariableLocator
......@@ -295,9 +297,6 @@ void Reducer::PrepareForBackward(
next_group_ = 0;
std::for_each(groups_.begin(), groups_.end(), [](Group &group) {
group.pending_ = group.variable_indices_.size();
group.all_length_ = 0;
group.dense_tensors_.clear();
group.dense_tensors_.reserve(group.pending_);
group.sparse_contents_ = nullptr;
});
......@@ -423,22 +422,35 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
auto group_index = var_locator.group_index;
auto &group = groups_[group_index];
if (is_used_var) {
auto var_warpper = vars_[var_index]->GradVarBase()->SharedVar();
if (!group.is_sparse_) {
auto grad = var_warpper->MutableVar();
auto inside_group_index = var_locator.inside_group_index;
auto length = group.length_[inside_group_index];
auto tensor = grad->GetMutable<framework::LoDTensor>();
framework::Tensor tmp;
tmp.ShareDataWith(*tensor).Resize({static_cast<int64_t>(length)});
group.dense_tensors_.push_back(std::move(tmp));
group.all_length_ += length;
if (!group.is_sparse_) {
// process dense group
auto inside_group_index = var_locator.inside_group_index;
auto length = group.length_[inside_group_index];
auto &group_tensor = group.dense_tensors_[inside_group_index];
if (is_used_var) {
auto var_warpper = vars_[var_index]->GradVarBase()->SharedVar();
auto tensor =
var_warpper->MutableVar()->GetMutable<framework::LoDTensor>();
group_tensor.ShareDataWith(*tensor).Resize(
{static_cast<int64_t>(length)});
} else {
if (!group_tensor.IsInitialized()) {
group_tensor.Resize({static_cast<int64_t>(length)});
group_tensor.mutable_data(place_, group.dtype_);
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place_);
operators::math::set_constant(*dev_ctx, &group_tensor, 0.0);
}
}
} else {
// process sparse group
if (is_used_var) {
auto var_warpper = vars_[var_index]->GradVarBase()->SharedVar();
group.sparse_contents_ = var_warpper->MutableVar();
} else {
group.sparse_contents_ = nullptr;
}
}
if (--group.pending_ == 0) {
// can start allreduce
MarkGroupReady(group_index);
......@@ -478,24 +490,19 @@ void Reducer::MarkGroupReady(size_t group_index) {
<< "] has no var to allreduce";
}
} else {
if (!group.dense_tensors_.empty()) {
VLOG(3) << "dense group [" << next_group_
<< "] start allreduce in ring[" << run_order << "]";
// Select common commstream to concat tensors
// group.dense_tensors ---> group.dense_contents_
group.ConcatTensors(*parallel_ctx_->GetDeviceContext(run_order));
// Start allreduce
parallel_ctx_->AllReduceByStream(
group.dense_contents_, &(group.dense_contents_), run_order, false);
// Select common commstream to split tensors
// group.dense_contents_ ---> group.dense_tensors
group.SplitTensors(*parallel_ctx_->GetDeviceContext(run_order));
} else {
VLOG(3) << "The dense group[" << next_group_
<< "] has no var to allreduce";
}
VLOG(3) << "dense group [" << next_group_ << "] start allreduce in ring["
<< run_order << "]";
// Select common commstream to concat tensors
// group.dense_tensors ---> group.dense_contents_
group.ConcatTensors(*parallel_ctx_->GetDeviceContext(run_order));
// Start allreduce
parallel_ctx_->AllReduceByStream(
group.dense_contents_, &(group.dense_contents_), run_order, false);
// Select common commstream to split tensors
// group.dense_contents_ ---> group.dense_tensors
group.SplitTensors(*parallel_ctx_->GetDeviceContext(run_order));
}
}
}
......
......@@ -29,6 +29,7 @@
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/string_helper.h"
#if defined(PADDLE_WITH_NCCL)
......@@ -201,7 +202,8 @@ class Reducer {
int nrings_ = 1;
// Following variables are to help rebuild group
bool has_rebuilt_group_{false};
// TODO(shenliang03): Support rebuild in the future.
bool has_rebuilt_group_{true};
std::vector<std::shared_ptr<imperative::VarBase>> rebuild_vars_;
std::vector<int64_t> rebuild_var_indices_;
const std::vector<size_t> group_size_limits_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册