Unverified commit df0430dc, authored by ShenLiang, committed by GitHub

[cherry-pick]fix matmulv2 bug & add rebuild group & fix bug of download (#29726)

* Fix the download bug in the case of multiple machines (#29551)

* fix the download bug
* add sort for ips

* Fix bug of matmul_v2 for broadcast case (#29599)

* fix bug of matmul_v2 for broadcast

* Rebuild group automatically in dynamic graph distributed (#29255)

* add tensor_indices in AssignGroupBySize

* add rebuild group in reducer

* fix error message of gather nd (#29521)
Parent ef04d3d3
@@ -20,47 +20,98 @@ namespace imperative {
#if defined(PADDLE_WITH_NCCL)
std::shared_ptr<Reducer> Reducer::s_instance_ = NULL;
// context is used to select the stream for concat
void Group::ConcatTensors(const platform::CUDADeviceContext &context) {
switch (dtype_) {
case framework::proto::VarType::FP16:
ConcatTensorsForAllReduce<platform::float16>(context, dense_tensors_,
&dense_contents_);
break;
case framework::proto::VarType::FP32:
ConcatTensorsForAllReduce<float>(context, dense_tensors_,
&dense_contents_);
break;
case framework::proto::VarType::FP64:
ConcatTensorsForAllReduce<double>(context, dense_tensors_,
&dense_contents_);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when it concats tensors for "
"allreduce.",
framework::DataTypeToString(dtype_)));
}
}
// context is used to select the stream for split
void Group::SplitTensors(const platform::CUDADeviceContext &context) {
switch (dtype_) {
case framework::proto::VarType::FP16:
SplitTensorsForAllReduce<platform::float16>(context, &dense_contents_,
&dense_tensors_);
break;
case framework::proto::VarType::FP32:
SplitTensorsForAllReduce<float>(context, &dense_contents_,
&dense_tensors_);
break;
case framework::proto::VarType::FP64:
SplitTensorsForAllReduce<double>(context, &dense_contents_,
&dense_tensors_);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when it splits tensors for "
"allreduce.",
framework::DataTypeToString(dtype_)));
}
}
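ConcatTensors and SplitTensors above are the fuse/unfuse halves of the Reducer's gradient bucketing: each group's dense gradients are concatenated into one contiguous buffer (dense_contents_), a single allreduce runs on that buffer, and the result is split back into the per-variable tensors. Below is a minimal NumPy sketch of that round trip; the helper names are illustrative, not Paddle APIs, and the allreduce is mocked as a sum over two simulated ranks (in the real code it is one NCCL call on the fused buffer).

```python
import numpy as np

def concat_for_allreduce(grads):
    """Flatten and concatenate per-variable gradients into one fused buffer."""
    return np.concatenate([g.ravel() for g in grads]), [g.shape for g in grads]

def split_after_allreduce(fused, shapes):
    """Split the fused buffer back into tensors with the original shapes."""
    out, offset = [], 0
    for shape in shapes:
        n = int(np.prod(shape))
        out.append(fused[offset:offset + n].reshape(shape))
        offset += n
    return out

# Two simulated ranks, each holding gradients for the same two variables.
rank_grads = [[np.ones((2, 3)), np.full(4, 2.0)],
              [np.full((2, 3), 3.0), np.full(4, 4.0)]]
fused, shapes = zip(*(concat_for_allreduce(g) for g in rank_grads))
reduced = fused[0] + fused[1]             # stands in for one allreduce on the fused buffer
grads = split_after_allreduce(reduced, shapes[0])
print([g.shape for g in grads], grads[1])  # [(2, 3), (4,)] [6. 6. 6. 6.]
```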
std::ostream &operator<<(std::ostream &out, const Group &group) {
const auto &vars = group.variable_indices_;
out << "numul: " << group.all_length_ << " ;is_sparse: " << group.is_sparse_
<< " ;var number: " << vars.size() << "\n";
auto begin = vars.begin();
auto end = vars.end();
out << "[";
for (int i = 0; begin != end && i < 100; ++i, ++begin) {
if (i > 0) out << ' ';
out << *begin;
}
if (begin != end) {
out << " ...";
}
out << "]\n";
return out;
}
Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
const std::vector<std::vector<size_t>> &group_indices,
const std::vector<bool> &is_sparse_gradient,
std::shared_ptr<imperative::ParallelContext> parallel_ctx,
const std::vector<size_t> &group_size_limits)
: vars_(vars),
group_indices_(group_indices),
is_sparse_gradient_(is_sparse_gradient),
parallel_ctx_(parallel_ctx),
group_size_limits_(group_size_limits) {
VLOG(3) << "Start construct the Reducer ...";
// initialize groups
InitializeGroups(group_indices);
for (size_t global_var_index = 0; global_var_index < vars_.size();
++global_var_index) {
vars_[global_var_index]->SharedVar()->AddGradVarLeafBackwardHook(
std::unique_ptr<LambdaGradAccumulatorPostHook>(
new LambdaGradAccumulatorPostHook([=](VariableWrapper *grad) {
this->AddDistHook(grad, global_var_index);
})));
}
// create streams
compute_stream_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
comm_stream_ = platform::NCCLCommContext::Instance().Get(0, place_)->stream();
// create events
CreateGroupEvents(group_indices.size());
comm_enent_ = platform::CudaEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::CUDAPlace, place_).device);
@@ -76,7 +127,20 @@ void Reducer::ReleaseReducer() {
comm_enent_.reset();
}
void Reducer::CreateGroupEvents(int group_num) {
// release old events
for (auto &event : events_) {
event.reset();
}
events_.clear();
events_.resize(group_num);
for (auto &event : events_) {
event = platform::CudaEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::CUDAPlace, place_).device);
}
}
void Reducer::InitializeDenseGroups(
const std::vector<size_t> &variable_indices_, Group *p_group) {
int64_t all_length = 0;
for (size_t index = 0; index < variable_indices_.size(); ++index) {
@@ -85,18 +149,18 @@ int64_t Reducer::InitializeDenseGroups(
const auto var_name = var->Name();
PADDLE_ENFORCE_EQ(is_sparse_gradient_[variable_index], false,
platform::errors::PreconditionNotMet(
"Tensor %s's GRAD must be LoDTensor, but received "
"GRAD is SelectedRows",
var_name));
auto lod_tensor = var->MutableVar()->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(lod_tensor->IsInitialized(), true,
platform::errors::PreconditionNotMet(
"Tensor %s is not initialized.", var_name));
auto size = lod_tensor->numel();
PADDLE_ENFORCE_GT(
size, 0, platform::errors::PreconditionNotMet(
"The number of tensor %s's elements is 0.", var_name));
all_length += size;
p_group->length_.push_back(size);
@@ -124,7 +188,7 @@ int64_t Reducer::InitializeDenseGroups(
place_ = place;
}
}
p_group->all_length_ = all_length;
}
// Each parameter will be initialized according to the group information.
@@ -137,6 +201,8 @@ void Reducer::InitializeGroups(
// clear the group
groups_.clear();
groups_.reserve(group_indices.size());
variable_locators_.clear();
variable_locators_.resize(vars_.size());
auto group_nums = group_indices.size();
for (size_t group_index = 0; group_index < group_nums; ++group_index) {
@@ -144,10 +210,8 @@ void Reducer::InitializeGroups(
PADDLE_ENFORCE_GT(
variable_indices_.size(), 0,
platform::errors::PreconditionNotMet(
"The number of group[%d]'s elements is 0.", group_index));
Group group;
// It's just for check the sparse or dense
auto first_varbase = vars_[variable_indices_.front()];
@@ -159,17 +223,27 @@ void Reducer::InitializeGroups(
group.is_sparse_ = true;
} else {
// process the dense gradient.
InitializeDenseGroups(variable_indices_, &group);
// Alloc the continuous space
auto tensor = group.dense_contents_.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({group.all_length_}))
.mutable_data(place_, group.dtype_);
}
// map variables to this group by VariableLocator
size_t inside_group_index = 0;
for (const auto var_index : group_indices[group_index]) {
variable_locators_[var_index] = VariableLocator{
.group_index = group_index,
.inside_group_index = inside_group_index++,
};
}
group.variable_indices_ = std::move(variable_indices_);
groups_.emplace_back(std::move(group));
// Debug Message For Reducer
VLOG(3) << "The Group[" << group_index << "]:";
VLOG(3) << groups_.back();
}
}
@@ -192,11 +266,16 @@ void Reducer::PrepareForBackward() {
// counter is 0, it means that allreduce can be emitted, and
// concat + allreduce + split is emitted in turn according to next_group_.
// 3, FinalizeBackward: after the end, synchronize each stream.
void Reducer::AddDistHook(VariableWrapper *var_warpper, size_t var_index) {
const auto &var_locator = variable_locators_[var_index];
auto group_index = var_locator.group_index;
auto &group = groups_[group_index];
if (!has_rebuilt_group_) {
rebuild_vars_.push_back(vars_[var_index]);
rebuild_var_indices_.push_back(var_index);
}
if (!group.is_sparse_) {
// Only dense_contents_ need memory copy
MarkVariableReady(var_index, var_warpper);
@@ -211,21 +290,22 @@ void Reducer::AddDistHook(VariableWrapper *var_warpper,
}
}
void Reducer::MarkVariableReady(size_t var_index,
VariableWrapper *var_warpper) {
const auto &var_locator = variable_locators_[var_index];
auto group_index = var_locator.group_index;
auto inside_group_index = var_locator.inside_group_index;
auto &group = groups_[group_index];
auto length = group.length_[inside_group_index];
auto tensor = var_warpper->MutableVar()->GetMutable<framework::LoDTensor>();
group.dense_tensors_[inside_group_index].ShareDataWith(*tensor).Resize(
{static_cast<int64_t>(length)});
}
void Reducer::MarkGroupReady(size_t group_index) {
if (group_index > next_group_) {
VLOG(3) << "It will adjust the order of group in next batch automatically";
return;
}
@@ -257,10 +337,31 @@ void Reducer::MarkGroupReady(size_t group_index) {
}
}
std::vector<std::vector<size_t>> Reducer::RebuildGruops() {
std::reverse(rebuild_vars_.begin(), rebuild_vars_.end());
std::reverse(rebuild_var_indices_.begin(), rebuild_var_indices_.end());
auto rebuild_group_indices =
AssignGroupBySize(rebuild_vars_, is_sparse_gradient_, group_size_limits_,
rebuild_var_indices_);
has_rebuilt_group_ = true;
rebuild_vars_.clear();
rebuild_var_indices_.clear();
std::reverse(rebuild_group_indices.begin(), rebuild_group_indices.end());
return rebuild_group_indices;
}
void Reducer::FinalizeBackward() {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(comm_enent_.get(), comm_stream_));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamWaitEvent(compute_stream_, comm_enent_.get(), 0));
if (!has_rebuilt_group_) {
VLOG(3) << "Start rebuilding the groups";
auto rebuild_group_indices = RebuildGruops();
auto rebuild_group_number = rebuild_group_indices.size();
group_indices_ = std::move(rebuild_group_indices);
CreateGroupEvents(rebuild_group_number);
InitializeGroups(group_indices_);
}
VLOG(3) << "In the batch, Reducer is finished..."; VLOG(3) << "In the batch, Reducer is finished...";
} }
...@@ -274,12 +375,28 @@ void Reducer::FinalizeBackward() { ...@@ -274,12 +375,28 @@ void Reducer::FinalizeBackward() {
std::vector<std::vector<size_t>> AssignGroupBySize( std::vector<std::vector<size_t>> AssignGroupBySize(
const std::vector<std::shared_ptr<imperative::VarBase>> &vars, const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
const std::vector<bool> &is_sparse_gradient, const std::vector<bool> &is_sparse_gradient,
const std::vector<size_t> &group_size_limits) { const std::vector<size_t> &group_size_limits,
const std::vector<int64_t> &tensor_indices) {
PADDLE_ENFORCE_EQ(vars.size(), is_sparse_gradient.size(), PADDLE_ENFORCE_EQ(vars.size(), is_sparse_gradient.size(),
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"vars len must be equal to is_sparse_gradient len, but " "vars len must be equal to is_sparse_gradient len, but "
"[%lu] != [%lu]", "[%lu] != [%lu]",
vars.size(), is_sparse_gradient.size())); vars.size(), is_sparse_gradient.size()));
auto check_perm = [](const std::vector<int64_t> &x) -> bool {
size_t len = x.size();
std::vector<size_t> cnt(len, 0);
for (size_t i = 0; i < len; ++i) {
if (x[i] >= static_cast<int64_t>(len) || x[i] < 0 || cnt[x[i]]) {
return false;
}
cnt[x[i]]++;
}
return true;
};
PADDLE_ENFORCE_EQ(true, check_perm(tensor_indices),
platform::errors::PreconditionNotMet(
"tensor_indices must be a permutation from 0 to %lu",
tensor_indices.size()));
// the return vector
std::vector<std::vector<size_t>> res;
@@ -294,9 +411,15 @@ std::vector<std::vector<size_t>> AssignGroupBySize(
for (size_t i = 0; i < vars.size(); ++i) {
const auto &var = vars[i];
size_t tensor_real_index = i;
if (!tensor_indices.empty()) {
tensor_real_index = tensor_indices[i];
}
if (is_sparse_gradient[tensor_real_index]) {
// we keep sparse var a single group
res.push_back({tensor_real_index});
continue;
}
@@ -313,7 +436,7 @@ std::vector<std::vector<size_t>> AssignGroupBySize(
<< " is not tensor or selected_rows, so skip it";
continue;
}
group_info.first.push_back(tensor_real_index);
group_info.second += framework::SizeOfType(var_dtype) * var_size;
if (group_limit_index.find(var_dtype_str) == group_limit_index.end()) {
@@ -344,10 +467,12 @@ std::vector<std::vector<size_t>> AssignGroupBySize(
platform::errors::PreconditionNotMet(
"AssignGroupBySize construct empty group, please check."));
}
if (tensor_indices.empty()) {
std::sort(res.begin(), res.end(),
[](const std::vector<size_t> &x, const std::vector<size_t> &y) {
return x.front() < y.front();
});
}
return res;
}
#endif
......
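The rebuild path added above works in two phases: during the first backward pass AddDistHook records every variable in the order its gradient becomes ready (rebuild_vars_ / rebuild_var_indices_), and FinalizeBackward reverses that order, regroups by size via AssignGroupBySize with tensor_indices, reverses the resulting group list, and reinitializes the groups, so that group 0 ends up holding the gradients that are ready earliest. Below is a hedged Python sketch of that flow; the greedy packer is a simplified stand-in for AssignGroupBySize (no dtype handling, byte sizes are illustrative).

```python
def assign_group_by_size(sizes, order, limit):
    """Greedily pack variable indices (visited in `order`) into groups under a byte limit."""
    groups, current, current_bytes = [], [], 0
    for idx in order:
        current.append(idx)
        current_bytes += sizes[idx]
        if current_bytes >= limit:
            groups.append(current)
            current, current_bytes = [], 0
    if current:
        groups.append(current)
    return groups

ready_order = [3, 2, 1, 0]               # order the grads became ready in the first backward
sizes = {0: 40, 1: 60, 2: 100, 3: 20}    # gradient sizes in bytes (illustrative)

# Mirror Reducer::RebuildGruops: reverse the recorded order, regroup, reverse the group list.
groups = assign_group_by_size(sizes, list(reversed(ready_order)), limit=100)
groups.reverse()
print(groups)  # [[3], [2], [0, 1]] -- group 0 now holds the gradient that was ready first
```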
@@ -86,6 +86,8 @@ class Group {
std::vector<framework::Tensor> dense_tensors_;
std::vector<size_t> length_;
int64_t all_length_{0};
// Global indices of participating variables in the group
std::vector<size_t> variable_indices_;
@@ -97,53 +99,15 @@ class Group {
framework::proto::VarType::Type dtype_;
// context is used to select the stream for concat
void ConcatTensors(const platform::CUDADeviceContext& context);
// context is used to select the stream for split
void SplitTensors(const platform::CUDADeviceContext& context);
friend std::ostream& operator<<(std::ostream&, const Group&);
};
struct VariableLocator {
// record the index in groups_
size_t group_index;
size_t inside_group_index;
@@ -155,22 +119,21 @@ class Reducer {
const std::vector<std::shared_ptr<imperative::VarBase>>& vars,
const std::vector<std::vector<size_t>>& group_indices,
const std::vector<bool>& is_sparse_gradient,
std::shared_ptr<imperative::ParallelContext> parallel_ctx,
const std::vector<size_t>& group_size_limits);
virtual ~Reducer() {}
void InitializeGroups(const std::vector<std::vector<size_t>>& group_indices);
void InitializeDenseGroups(const std::vector<size_t>& variable_indices_,
Group* p_group);
void PrepareForBackward();
void AddDistHook(VariableWrapper* var_warpper, size_t var_index);
void MarkVariableReady(size_t var_index, VariableWrapper* var_warpper);
void MarkGroupReady(size_t group_index);
@@ -178,15 +141,21 @@ class Reducer {
void ReleaseReducer();
std::vector<std::vector<size_t>> RebuildGruops();
void CreateGroupEvents(int group_num);
// Reducer Singleton
static std::shared_ptr<Reducer> SetInstance(
const std::vector<std::shared_ptr<imperative::VarBase>>& vars,
const std::vector<std::vector<size_t>>& group_indices,
const std::vector<bool>& is_sparse_gradient,
std::shared_ptr<imperative::ParallelContext> parallel_ctx,
const std::vector<size_t>& group_size_limits) {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::imperative::Reducer(
vars, group_indices, is_sparse_gradient, parallel_ctx,
group_size_limits));
}
return s_instance_;
}
@@ -208,17 +177,26 @@ class Reducer {
std::once_flag once_flag_;
std::vector<bool> is_sparse_gradient_;
std::shared_ptr<imperative::ParallelContext> parallel_ctx_;
std::vector<VariableLocator> variable_locators_;
// Following variables are to help sync stream
std::vector<std::shared_ptr<platform::CudaEventObject>> events_;
std::shared_ptr<platform::CudaEventObject> comm_enent_;
cudaStream_t compute_stream_;
cudaStream_t comm_stream_;
// Following variables are to help rebuild group
bool has_rebuilt_group_{false};
std::vector<std::shared_ptr<imperative::VarBase>> rebuild_vars_;
std::vector<int64_t> rebuild_var_indices_;
const std::vector<size_t> group_size_limits_;
};
std::vector<std::vector<size_t>> AssignGroupBySize(
const std::vector<std::shared_ptr<imperative::VarBase>>& tensors,
const std::vector<bool>& is_sparse_gradient,
const std::vector<size_t>& group_size_limits,
const std::vector<int64_t>& tensor_indices = {});
#endif
} // namespace imperative
......
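AssignGroupBySize now optionally takes tensor_indices: when it is non-empty, vars[i] is treated as the variable whose real (global) index is tensor_indices[i], the indices must form a permutation of 0..n-1 (check_perm), and the final sort-by-first-index is skipped so the caller-supplied order is preserved. A small Python sketch of the permutation check and of the local-position to real-index mapping (illustrative only, not the C++ implementation):

```python
def is_permutation(indices):
    """True iff `indices` is a permutation of 0..len(indices)-1 (mirrors check_perm)."""
    n = len(indices)
    seen = [False] * n
    for v in indices:
        if v < 0 or v >= n or seen[v]:
            return False
        seen[v] = True
    return True

tensor_indices = [3, 0, 1, 2]       # position i of `vars` really is variable tensor_indices[i]
assert is_permutation(tensor_indices)
assert not is_permutation([3, 0, 0, 2])

# Groups record the *real* indices, so packing positions 0 and 1 together yields real
# indices [3, 0] -- compare test_construct_group8 below, which expects [[3, 0], [1], [2]].
print({pos: real for pos, real in enumerate(tensor_indices)})   # {0: 3, 1: 0, 2: 1, 3: 2}
```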
@@ -12,3 +12,7 @@ cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry
cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place)
cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy)
if (WITH_NCCL)
cc_test(test_group SRCS test_group.cc DEPS reducer concat_and_split memcpy)
endif()
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <ostream>
#include <sstream>
#include <string>
#include "glog/logging.h"
#include "gtest/gtest.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/imperative/reducer.h"
#endif
namespace paddle {
namespace imperative {
#if defined(PADDLE_WITH_NCCL)
TEST(TestGroup, TestPrintGroupMessage) {
Group group;
std::stringstream stream1, stream2;
stream1 << group;
ASSERT_STREQ(stream1.str().c_str(),
"numul: 0 ;is_sparse: 0 ;var number: 0\n[]\n");
std::vector<size_t> vars;
size_t vars_num = 102;
for (size_t i = 0; i < vars_num; ++i) {
vars.push_back(i);
}
group.variable_indices_ = vars;
group.all_length_ = 102;
group.is_sparse_ = false;
std::string head = "numul: 102 ;is_sparse: 0 ;var number: 102\n";
head = head + "[";
auto begin = vars.begin();
auto end = vars.end();
for (int i = 0; begin != end && i < 100; ++i, ++begin) {
if (i > 0) head += ' ';
head += std::to_string(*begin);
}
if (begin != end) {
head += " ...";
}
head += "]\n";
stream2 << group;
ASSERT_STREQ(stream2.str().c_str(), head.c_str());
}
#endif
} // namespace imperative
} // namespace paddle
@@ -53,7 +53,13 @@ __global__ void GatherNdCUDAKernel(const T* input, const int* input_dims,
int64_t temp = slice_size;
for (int64_t j = end_size - 1; j >= 0; --j) {
auto index_value = indices[indices_i * end_size + j];
PADDLE_ENFORCE(
index_value >= 0 && index_value < input_dims[j],
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d] and greater or equal to 0, but received [%d]",
input_dims[j], index_value);
gather_i += (index_value * temp);
temp *= input_dims[j];
}
......
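The kernel above flattens an N-d index into a linear offset (innermost dimension last) and, with this change, reports an out-of-range index through PADDLE_ENFORCE with a readable message instead of a bare device-side assert. A NumPy sketch of the same offset computation and bounds check, assuming the indexed slice is a single element (end_size equals the input rank, so slice_size == 1); names are illustrative:

```python
import numpy as np

def gather_nd_offset(index, input_dims):
    """Linear offset of one fully-specified N-d index, with the same bounds check."""
    gather_i, temp = 0, 1
    for j in range(len(index) - 1, -1, -1):
        if not 0 <= index[j] < input_dims[j]:
            raise IndexError("index {} is out of bounds for dim {} with size {}".format(
                index[j], j, input_dims[j]))
        gather_i += index[j] * temp
        temp *= input_dims[j]
    return gather_i

x = np.arange(24).reshape(2, 3, 4)
print(x.ravel()[gather_nd_offset((1, 2, 3), x.shape)])   # 23, i.e. x[1, 2, 3]
# gather_nd_offset((1, 3, 0), x.shape) would raise, mirroring the new error message.
```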
@@ -44,7 +44,6 @@ template <typename DeviceContext, typename T>
void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
const std::vector<int>& reduce_dims,
const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__
auto stream = ctx.cuda_device_context().stream();
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
@@ -573,47 +572,48 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
// So we should avoid the case in reality.
VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
"wastes the memory. So we should avoid the case in reality";
Tensor dx_help, dy_help;
if (transpose_x) {
if (transpose_y) {
// X'Y': dA = Y'G', dB = G'X'
if (dx)
MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
&dx_help, true, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
&dy_help, true, true, ctx);
} else {
// X'Y: dX = YG', dY = XG
if (dx)
MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
&dx_help, false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
&dy_help, false, false, ctx);
}
} else {
if (transpose_y) {
// XY': dX = GY, dY = G'X
if (dx)
MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
&dx_help, false, false, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
&dy_help, true, false, ctx);
} else {
// XY: dX = GY', dY = X'G
if (dx)
MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
&dx_help, false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
&dy_help, true, false, ctx);
}
}
// get help dims
const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
std::vector<std::int64_t> dx_broadcast_dims(ndim);
std::vector<std::int64_t> dy_broadcast_dims(ndim);
@@ -639,11 +639,21 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
}
// reduce sum to get grad by ReduceSum
if (dx) {
if (dx_reduce_dims.empty()) {
*dx = std::move(dx_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
ctx);
}
dx->Resize(x.dims());
}
if (dy) {
if (dy_reduce_dims.empty()) {
*dy = std::move(dy_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
ctx);
}
dy->Resize(y.dims());
}
}
......
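The reworked gradient path computes dX/dY into temporaries (dx_help/dy_help) at the broadcast shape and only then reduce-sums over the broadcast axes back to the input shapes; when nothing was broadcast the temporary is moved into the output directly. A NumPy sketch of why that reduction is needed, for shapes like the new broadcast tests, x of shape (3, 1, 10, 10) against y of shape (1, 2, 10, 10) (illustrative check only, not the kernel code):

```python
import numpy as np

x = np.random.rand(3, 1, 10, 10)
y = np.random.rand(1, 2, 10, 10)
dout = np.ones((3, 2, 10, 10))           # upstream gradient of z = x @ y (loss = z.sum())

# dX at the broadcast shape (XY case: dX = G @ Y^T), then reduce over y's broadcast axis.
dx_help = np.matmul(dout, np.swapaxes(y, -1, -2))   # shape (3, 2, 10, 10)
dx = dx_help.sum(axis=1, keepdims=True)             # back to x's shape (3, 1, 10, 10)
assert dx.shape == x.shape

# Finite-difference check on one element: d z.sum() / d x[0, 0, 0, 0]
eps = 1e-6
xp = x.copy(); xp[0, 0, 0, 0] += eps
numeric = (np.matmul(xp, y).sum() - np.matmul(x, y).sum()) / eps
assert np.allclose(dx[0, 0, 0, 0], numeric, rtol=1e-3)
```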
@@ -1289,9 +1289,11 @@ void BindImperative(py::module *m_ptr) {
[](const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
const std::vector<std::vector<size_t>> &group_indices,
const std::vector<bool> &is_sparse_gradient,
std::shared_ptr<imperative::ParallelContext> parallel_ctx,
const std::vector<size_t> &group_size_limits) {
return imperative::Reducer::SetInstance(
vars, group_indices, is_sparse_gradient, parallel_ctx,
group_size_limits);
}))
.def("prepare_for_backward", &imperative::Reducer::PrepareForBackward,
py::call_guard<py::gil_scoped_release>());
@@ -1299,6 +1301,7 @@ void BindImperative(py::module *m_ptr) {
m.def("assign_group_by_size", &imperative::AssignGroupBySize, py::arg("vars"),
py::arg("is_sparse_gradient"),
py::arg("group_size_limits") = std::vector<size_t>{25 * 1024 * 1024},
py::arg("tensor_indices") = std::vector<int64_t>{},
py::call_guard<py::gil_scoped_release>());
#endif
}
......
@@ -18,7 +18,6 @@ from paddle.fluid.framework import Variable, set_flags, core
from paddle.fluid.wrapped_decorator import wrap_decorator
import google.protobuf.text_format
import google.protobuf
__all__ = ["DistributedStrategy"]
......
@@ -441,10 +441,11 @@ class DataParallel(layers.Layer):
"ParallelContext must be initialized before. You should use init_parallel_env() before" \
"constructing the DataParallel."
self._reducer = core.Reducer(
trainable_parameters,
list(reversed(self.group_indices)), is_sparse_gradient,
parallel_helper.__parallel_ctx__clz__,
[self.last_comm_buffer_size, self.comm_buffer_size])
def forward(self, *inputs, **kwargs):
if self._strategy.nranks > 1:
......
@@ -155,6 +155,30 @@ class TestDataParallelGroup(unittest.TestCase):
var_list, [True, False, False, False, False, True], [200, 400])
self.assertEqual([[0], [1], [2], [3], [4], [5]], res)
def test_construct_group8(self):
# one dtype & one limit capability & have tensor_indices
var_list = []
var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
var_list.append(
self.create_varbase(core.VarDesc.VarType.FP32, [2, 100]))
var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 50]))
var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
res = core.assign_group_by_size(var_list, [False, False, False, False],
[400], [3, 0, 1, 2])
self.assertEqual([[3, 0], [1], [2]], res)
def test_construct_group9(self):
# one dtype & one limit capability & have tensor_indices
var_list = []
var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
var_list.append(
self.create_varbase(core.VarDesc.VarType.FP32, [2, 1000]))
res = core.assign_group_by_size(var_list, [False, False, False, True],
[300], [1, 0, 2, 3])
self.assertEqual([[1, 0], [3], [2]], res)
if __name__ == '__main__':
unittest.main()
@@ -286,6 +286,30 @@ class TestMatMuklOp17(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOpBroadcast1(TestMatMulV2Op):
"""
case 14_3
"""
def config(self):
self.x_shape = (3, 1, 10, 10)
self.y_shape = (1, 2, 10, 10)
self.trans_x = True
self.trans_y = True
class TestMatMuklOpBroadcast2(TestMatMulV2Op):
"""
case 14_4
"""
def config(self):
self.x_shape = (3, 1, 10, 10)
self.y_shape = (1, 2, 10, 10)
self.trans_x = False
self.trans_y = True
#--------------------test matmul fp16--------------------
......
@@ -140,6 +140,21 @@ def _map_path(url, root_dir):
return osp.join(root_dir, fpath)
def _get_unique_endpoints(trainer_endpoints):
# Sorting avoids the endpoint order differing across cards due to environment variables
trainer_endpoints.sort()
ips = set()
unique_endpoints = set()
for endpoint in trainer_endpoints:
ip = endpoint.split(":")[0]
if ip in ips:
continue
ips.add(ip)
unique_endpoints.add(endpoint)
logger.info("unique_endpoints {}".format(unique_endpoints))
return unique_endpoints
def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
@@ -161,17 +176,20 @@ def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
assert is_url(url), "downloading from {} not a url".format(url)
# parse path after download to decompress under root_dir
fullpath = _map_path(url, root_dir)
# Mainly used to solve the problem of downloading data from different
# machines in the case of multiple machines. Different ips will download
# data, and the same ip will only download data once.
unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum)
else:
while not os.path.exists(fullpath):
time.sleep(1)
if ParallelEnv().current_endpoint in unique_endpoints:
if tarfile.is_tarfile(fullpath) or zipfile.is_zipfile(fullpath):
fullpath = _decompress(fullpath)
......
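The download fix above lets exactly one trainer per machine download and decompress the file: endpoints are sorted so every card computes the same set, only the first endpoint of each IP is kept, and a trainer downloads only if its own endpoint is in that set, otherwise it waits for the file to appear. A small sketch mirroring _get_unique_endpoints, with hypothetical endpoint values:

```python
def get_unique_endpoints(trainer_endpoints):
    """Keep one endpoint per IP; sorting makes the choice identical on every card."""
    ips, unique = set(), set()
    for ep in sorted(trainer_endpoints):
        ip = ep.split(":")[0]
        if ip not in ips:
            ips.add(ip)
            unique.add(ep)
    return unique

endpoints = ["10.0.0.2:6071", "10.0.0.1:6070", "10.0.0.1:6071", "10.0.0.2:6070"]
unique = get_unique_endpoints(endpoints)
print(unique)                     # {'10.0.0.1:6070', '10.0.0.2:6070'}
print("10.0.0.1:6071" in unique)  # False -> this rank waits instead of downloading
```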