未验证 提交 dae3e1f3 编写于 作者: S ShenLiang 提交者: GitHub

Solve inconsistent order in each card in dynamic graph (#30931)

上级 14d039e4
......@@ -181,11 +181,6 @@ void SplitTensorsWithType<platform::XPUDeviceContext>(
#endif
void Group::ConcatTensors(const platform::DeviceContext &context) {
VLOG(3) << "Before concat, set output tensor size is " << all_length_;
auto tensor = dense_contents_.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({all_length_}))
.mutable_data(context.GetPlace(), dtype_);
auto place = context.GetPlace();
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_NCCL
......@@ -320,6 +315,9 @@ void Reducer::InitializeDenseGroups(
p_group->length_.push_back(size);
// for concat operator
p_group->dense_tensors_.push_back(framework::Tensor());
// check the dtype and place, it must be same.
auto dtype = var->DataType();
auto place = var->Place();
......@@ -341,6 +339,7 @@ void Reducer::InitializeDenseGroups(
place_ = place;
}
}
p_group->all_length_ = all_length;
}
// Each parameter will be initialized according to the group information.
......@@ -375,6 +374,9 @@ void Reducer::InitializeGroups(
} else {
// process the dense gradient.
InitializeDenseGroups(variable_indices_, &group);
auto tensor = group.dense_contents_.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({group.all_length_}))
.mutable_data(place_, group.dtype_);
}
// map variables to this group by VariableLocator
......@@ -436,9 +438,6 @@ void Reducer::PrepareForBackward(
next_group_ = 0;
std::for_each(groups_.begin(), groups_.end(), [](Group &group) {
group.pending_ = group.variable_indices_.size();
group.all_length_ = 0;
group.dense_tensors_.clear();
group.dense_tensors_.reserve(group.pending_);
group.sparse_contents_ = nullptr;
});
......@@ -564,22 +563,42 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
auto group_index = var_locator.group_index;
auto &group = groups_[group_index];
if (is_used_var) {
auto var_warpper = vars_[var_index]->GradVarBase()->SharedVar();
if (!group.is_sparse_) {
auto grad = var_warpper->MutableVar();
auto inside_group_index = var_locator.inside_group_index;
auto length = group.length_[inside_group_index];
auto tensor = grad->GetMutable<framework::LoDTensor>();
framework::Tensor tmp;
tmp.ShareDataWith(*tensor).Resize({static_cast<int64_t>(length)});
group.dense_tensors_.push_back(std::move(tmp));
group.all_length_ += length;
if (!group.is_sparse_) {
// process dense group
auto inside_group_index = var_locator.inside_group_index;
auto length = group.length_[inside_group_index];
auto &group_tensor = group.dense_tensors_[inside_group_index];
if (is_used_var) {
auto var_warpper = vars_[var_index]->GradVarBase()->SharedVar();
auto tensor =
var_warpper->MutableVar()->GetMutable<framework::LoDTensor>();
group_tensor.ShareDataWith(*tensor).Resize(
{static_cast<int64_t>(length)});
} else {
if (!group_tensor.IsInitialized()) {
group_tensor.Resize({static_cast<int64_t>(length)});
group_tensor.mutable_data(place_, group.dtype_);
#ifdef PADDLE_WITH_XPU_BKCL
if (platform::is_xpu_place(group_tensor.place())) {
// TODO(liuyuhui) support XPU set constant
VLOG(3) << "XPU doesn't support set_constant";
}
#else
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place_);
operators::math::set_constant(*dev_ctx, &group_tensor, 0.0);
#endif
}
}
} else {
// process sparse group
if (is_used_var) {
auto var_warpper = vars_[var_index]->GradVarBase()->SharedVar();
group.sparse_contents_ = var_warpper->MutableVar();
} else {
group.sparse_contents_ = nullptr;
}
}
if (--group.pending_ == 0) {
// can start allreduce
MarkGroupReady(group_index);
......@@ -619,36 +638,30 @@ void Reducer::MarkGroupReady(size_t group_index) {
<< "] has no var to allreduce";
}
} else {
if (!group.dense_tensors_.empty()) {
VLOG(3) << "dense group [" << next_group_
<< "] start allreduce in ring[" << run_order << "]";
// Select common commstream to concat tensors
// group.dense_tensors ---> group.dense_contents_
group.ConcatTensors(*parallel_ctx_->GetDeviceContext(run_order));
VLOG(3) << "dense group [" << next_group_ << "] start allreduce in ring["
<< run_order << "]";
// Select common commstream to concat tensors
// group.dense_tensors ---> group.dense_contents_
group.ConcatTensors(*parallel_ctx_->GetDeviceContext(run_order));
// NOTE(liuyuhui): ConcatTensors use communication stream, but BKCL only support
// default stream for communicating,
// so there exist some problems in synchronization. And need to add a WaitComm
// there.
// default stream for communicating, so there exist some problems in
// synchronization. And need to add a WaitComm there.
// TODO(liuyuhui): If BKCL support events, it should be fixed as non-blocking
// communication.
#ifdef PADDLE_WITH_XPU_BKCL
if (platform::is_xpu_place(group.dense_tensors_[0].place())) {
parallel_ctx_->WaitComm(run_order);
}
if (platform::is_xpu_place(group.dense_tensors_[0].place())) {
parallel_ctx_->WaitComm(run_order);
}
#endif
// Start allreduce
parallel_ctx_->AllReduceByStream(
group.dense_contents_, &(group.dense_contents_), run_order, false);
// Start allreduce
parallel_ctx_->AllReduceByStream(
group.dense_contents_, &(group.dense_contents_), run_order, false);
// Select common commstream to split tensors
// group.dense_contents_ ---> group.dense_tensors
group.SplitTensors(*parallel_ctx_->GetDeviceContext(run_order));
} else {
VLOG(3) << "The dense group[" << next_group_
<< "] has no var to allreduce";
}
// Select common commstream to split tensors
// group.dense_contents_ ---> group.dense_tensors
group.SplitTensors(*parallel_ctx_->GetDeviceContext(run_order));
}
}
}
......
......@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace platform {
......@@ -133,7 +134,8 @@ class Reducer {
int nrings_ = 1;
// Following variables are to help rebuild group
bool has_rebuilt_group_{false};
// TODO(shenliang03): Support rebuild in the future.
bool has_rebuilt_group_{true};
std::vector<std::shared_ptr<imperative::VarBase>> rebuild_vars_;
std::vector<int64_t> rebuild_var_indices_;
const std::vector<size_t> group_size_limits_;
......
......@@ -94,9 +94,11 @@ void GroupConcatSplit(Place place, size_t size) {
auto* dev_ctx = pool.Get(place);
{ // concat
auto* tensor = group.dense_contents_.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({group.all_length_}))
.mutable_data(place, group.dtype_);
group.ConcatTensors(*dev_ctx);
auto* tensor = group.dense_contents_.GetMutable<framework::LoDTensor>();
framework::Tensor tmp;
framework::TensorCopySync(*tensor, cpu_place, &tmp);
auto* data = tmp.data<T>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册