From 2ef9e0e23c92571d43b65b155c799aa1dd858d4a Mon Sep 17 00:00:00 2001 From: ShenLiang Date: Wed, 9 Dec 2020 15:38:34 +0800 Subject: [PATCH] Rebuild group automatically in dynamic graph distributed (#29255) * add tensor_indices in AssignGroupBySize * add rebuild group in reducer --- paddle/fluid/imperative/reducer.cc | 241 +++++++++++++----- paddle/fluid/imperative/reducer.h | 84 +++--- paddle/fluid/imperative/tests/CMakeLists.txt | 4 + paddle/fluid/imperative/tests/test_group.cc | 66 +++++ paddle/fluid/pybind/imperative.cc | 7 +- .../fleet/base/distributed_strategy.py | 1 - python/paddle/fluid/dygraph/parallel.py | 9 +- .../tests/unittests/test_imperative_group.py | 24 ++ 8 files changed, 318 insertions(+), 118 deletions(-) create mode 100644 paddle/fluid/imperative/tests/test_group.cc diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index 3f0703f05a8..54a2b647d42 100644 --- a/paddle/fluid/imperative/reducer.cc +++ b/paddle/fluid/imperative/reducer.cc @@ -20,47 +20,98 @@ namespace imperative { #if defined(PADDLE_WITH_NCCL) std::shared_ptr Reducer::s_instance_ = NULL; +// context is used to select the stream for concat +void Group::ConcatTensors(const platform::CUDADeviceContext &context) { + switch (dtype_) { + case framework::proto::VarType::FP16: + ConcatTensorsForAllReduce(context, dense_tensors_, + &dense_contents_); + break; + case framework::proto::VarType::FP32: + ConcatTensorsForAllReduce(context, dense_tensors_, + &dense_contents_); + break; + case framework::proto::VarType::FP64: + ConcatTensorsForAllReduce(context, dense_tensors_, + &dense_contents_); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when it concats tensors for " + "allreduce.", + framework::DataTypeToString(dtype_))); + } +} + +// context is used to select the stream for split +void Group::SplitTensors(const platform::CUDADeviceContext &context) { + switch (dtype_) { + case framework::proto::VarType::FP16: + SplitTensorsForAllReduce(context, &dense_contents_, + &dense_tensors_); + break; + case framework::proto::VarType::FP32: + SplitTensorsForAllReduce(context, &dense_contents_, + &dense_tensors_); + break; + case framework::proto::VarType::FP64: + SplitTensorsForAllReduce(context, &dense_contents_, + &dense_tensors_); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when it splits tensors for " + "allreduce.", + framework::DataTypeToString(dtype_))); + } +} + +std::ostream &operator<<(std::ostream &out, const Group &group) { + const auto &vars = group.variable_indices_; + out << "numul: " << group.all_length_ << " ;is_sparse: " << group.is_sparse_ + << " ;var number: " << vars.size() << "\n"; + auto begin = vars.begin(); + auto end = vars.end(); + out << "["; + for (int i = 0; begin != end && i < 100; ++i, ++begin) { + if (i > 0) out << ' '; + out << *begin; + } + if (begin != end) { + out << " ..."; + } + out << "]\n"; + return out; +} + Reducer::Reducer(const std::vector> &vars, const std::vector> &group_indices, const std::vector &is_sparse_gradient, - std::shared_ptr parallel_ctx) + std::shared_ptr parallel_ctx, + const std::vector &group_size_limits) : vars_(vars), group_indices_(group_indices), is_sparse_gradient_(is_sparse_gradient), - parallel_ctx_(parallel_ctx) { + parallel_ctx_(parallel_ctx), + group_size_limits_(group_size_limits) { VLOG(3) << "Start construct the Reducer ..."; // initialize groups InitializeGroups(group_indices); - - { - for (size_t group_index = 0; group_index < group_indices.size(); - ++group_index) { - for (size_t var_index = 0; var_index < group_indices[group_index].size(); - ++var_index) { - size_t global_var_index = group_indices[group_index][var_index]; - const auto variable_index = VariableIndex{ - .group_index = group_index, .inside_group_index = var_index, - }; - VLOG(3) << "add hook for var[" << vars_[global_var_index]->GradVarName() - << "], it's in group [" << group_index << "]"; - vars_[global_var_index]->SharedVar()->AddGradVarLeafBackwardHook( - std::unique_ptr( - new LambdaGradAccumulatorPostHook([=](VariableWrapper *grad) { - this->AddDistHook(grad, variable_index); - }))); - } - } + for (size_t global_var_index = 0; global_var_index < vars_.size(); + ++global_var_index) { + vars_[global_var_index]->SharedVar()->AddGradVarLeafBackwardHook( + std::unique_ptr( + new LambdaGradAccumulatorPostHook([=](VariableWrapper *grad) { + this->AddDistHook(grad, global_var_index); + }))); } - + // create streams compute_stream_ = static_cast( platform::DeviceContextPool::Instance().Get(place_)) ->stream(); comm_stream_ = platform::NCCLCommContext::Instance().Get(0, place_)->stream(); - events_.resize(group_indices.size()); - for (auto &event : events_) { - event = platform::CudaEventResourcePool::Instance().New( - BOOST_GET_CONST(platform::CUDAPlace, place_).device); - } + // create events + CreateGroupEvents(group_indices.size()); comm_enent_ = platform::CudaEventResourcePool::Instance().New( BOOST_GET_CONST(platform::CUDAPlace, place_).device); @@ -76,7 +127,20 @@ void Reducer::ReleaseReducer() { comm_enent_.reset(); } -int64_t Reducer::InitializeDenseGroups( +void Reducer::CreateGroupEvents(int group_num) { + // release old events + for (auto &event : events_) { + event.reset(); + } + events_.clear(); + events_.resize(group_num); + for (auto &event : events_) { + event = platform::CudaEventResourcePool::Instance().New( + BOOST_GET_CONST(platform::CUDAPlace, place_).device); + } +} + +void Reducer::InitializeDenseGroups( const std::vector &variable_indices_, Group *p_group) { int64_t all_length = 0; for (size_t index = 0; index < variable_indices_.size(); ++index) { @@ -85,18 +149,18 @@ int64_t Reducer::InitializeDenseGroups( const auto var_name = var->Name(); PADDLE_ENFORCE_EQ(is_sparse_gradient_[variable_index], false, platform::errors::PreconditionNotMet( - "Tensor `%s`'s GRAD must be LoDTensor, but received " + "Tensor %s's GRAD must be LoDTensor, but received " "GRAD is SelectedRows", var_name)); auto lod_tensor = var->MutableVar()->GetMutable(); PADDLE_ENFORCE_EQ(lod_tensor->IsInitialized(), true, platform::errors::PreconditionNotMet( - "Tensor `%s` is not initialized.", var_name)); + "Tensor %s is not initialized.", var_name)); auto size = lod_tensor->numel(); PADDLE_ENFORCE_GT( size, 0, platform::errors::PreconditionNotMet( - "The number of tensor `%s`'s elements is 0.", var_name)); + "The number of tensor %s's elements is 0.", var_name)); all_length += size; p_group->length_.push_back(size); @@ -124,7 +188,7 @@ int64_t Reducer::InitializeDenseGroups( place_ = place; } } - return all_length; + p_group->all_length_ = all_length; } // Each parameter will be initialized according to the group information. @@ -137,6 +201,8 @@ void Reducer::InitializeGroups( // clear the group groups_.clear(); groups_.reserve(group_indices.size()); + variable_locators_.clear(); + variable_locators_.resize(vars_.size()); auto group_nums = group_indices.size(); for (size_t group_index = 0; group_index < group_nums; ++group_index) { @@ -144,10 +210,8 @@ void Reducer::InitializeGroups( PADDLE_ENFORCE_GT( variable_indices_.size(), 0, platform::errors::PreconditionNotMet( - "The number of group_index[`%d`]'s elements is 0.", group_index)); + "The number of group[%d]'s elements is 0.", group_index)); Group group; - group.variable_indices_ = variable_indices_; - int64_t all_length = 0; // It's just for check the sparse or dense auto first_varbase = vars_[variable_indices_.front()]; @@ -159,17 +223,27 @@ void Reducer::InitializeGroups( group.is_sparse_ = true; } else { // process the dense gradient. - all_length = InitializeDenseGroups(variable_indices_, &group); + InitializeDenseGroups(variable_indices_, &group); // Alloc the continuous space auto tensor = group.dense_contents_.GetMutable(); - tensor->Resize(framework::make_ddim({all_length})) + tensor->Resize(framework::make_ddim({group.all_length_})) .mutable_data(place_, group.dtype_); } - // Debug Message For Reducer - VLOG(3) << "the groups_[" << group_index << "] basic message:"; - VLOG(3) << "numul: " << all_length << " ;is_sparse: " << group.is_sparse_ - << " ;var number: " << group.variable_indices_.size(); + + // map variables to this group by VariableLocator + size_t inside_group_index = 0; + for (const auto var_index : group_indices[group_index]) { + variable_locators_[var_index] = VariableLocator{ + .group_index = group_index, + .inside_group_index = inside_group_index++, + }; + } + group.variable_indices_ = std::move(variable_indices_); groups_.emplace_back(std::move(group)); + + // Debug Message For Reducer + VLOG(3) << "The Group[" << group_index << "]:"; + VLOG(3) << groups_.back(); } } @@ -192,11 +266,16 @@ void Reducer::PrepareForBackward() { // counter is 0, it means that allreduce can be emitted, and // concat + allreduce + split is emitted in turn according to next_group_. // 3, FinalizeBackward: after the end, synchronize each stream. -void Reducer::AddDistHook(VariableWrapper *var_warpper, - const VariableIndex &var_index) { - auto group_index = var_index.group_index; +void Reducer::AddDistHook(VariableWrapper *var_warpper, size_t var_index) { + const auto &var_locator = variable_locators_[var_index]; + auto group_index = var_locator.group_index; auto &group = groups_[group_index]; + if (!has_rebuilt_group_) { + rebuild_vars_.push_back(vars_[var_index]); + rebuild_var_indices_.push_back(var_index); + } + if (!group.is_sparse_) { // Only dense_contents_ need memory copy MarkVariableReady(var_index, var_warpper); @@ -211,21 +290,22 @@ void Reducer::AddDistHook(VariableWrapper *var_warpper, } } -void Reducer::MarkVariableReady(const VariableIndex &var_index, +void Reducer::MarkVariableReady(size_t var_index, VariableWrapper *var_warpper) { - auto group_index = var_index.group_index; - auto variable_index = var_index.inside_group_index; + const auto &var_locator = variable_locators_[var_index]; + auto group_index = var_locator.group_index; + auto inside_group_index = var_locator.inside_group_index; auto &group = groups_[group_index]; - auto length = group.length_[variable_index]; + auto length = group.length_[inside_group_index]; auto tensor = var_warpper->MutableVar()->GetMutable(); - group.dense_tensors_[variable_index].ShareDataWith(*tensor).Resize( + group.dense_tensors_[inside_group_index].ShareDataWith(*tensor).Resize( {static_cast(length)}); } void Reducer::MarkGroupReady(size_t group_index) { if (group_index > next_group_) { - VLOG(3) << "Maybe it need adjust the order of group"; + VLOG(3) << "It will adjust the order of group in next batch automatically"; return; } @@ -257,10 +337,31 @@ void Reducer::MarkGroupReady(size_t group_index) { } } +std::vector> Reducer::RebuildGruops() { + std::reverse(rebuild_vars_.begin(), rebuild_vars_.end()); + std::reverse(rebuild_var_indices_.begin(), rebuild_var_indices_.end()); + auto rebuild_group_indices = + AssignGroupBySize(rebuild_vars_, is_sparse_gradient_, group_size_limits_, + rebuild_var_indices_); + has_rebuilt_group_ = true; + rebuild_vars_.clear(); + rebuild_var_indices_.clear(); + std::reverse(rebuild_group_indices.begin(), rebuild_group_indices.end()); + return rebuild_group_indices; +} + void Reducer::FinalizeBackward() { PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(comm_enent_.get(), comm_stream_)); PADDLE_ENFORCE_CUDA_SUCCESS( cudaStreamWaitEvent(compute_stream_, comm_enent_.get(), 0)); + if (!has_rebuilt_group_) { + VLOG(3) << "Start rebuilding the groups"; + auto rebuild_group_indices = RebuildGruops(); + auto rebuild_group_number = rebuild_group_indices.size(); + group_indices_ = std::move(rebuild_group_indices); + CreateGroupEvents(rebuild_group_number); + InitializeGroups(group_indices_); + } VLOG(3) << "In the batch, Reducer is finished..."; } @@ -274,12 +375,28 @@ void Reducer::FinalizeBackward() { std::vector> AssignGroupBySize( const std::vector> &vars, const std::vector &is_sparse_gradient, - const std::vector &group_size_limits) { + const std::vector &group_size_limits, + const std::vector &tensor_indices) { PADDLE_ENFORCE_EQ(vars.size(), is_sparse_gradient.size(), platform::errors::PreconditionNotMet( "vars len must be equal to is_sparse_gradient len, but " "[%lu] != [%lu]", vars.size(), is_sparse_gradient.size())); + auto check_perm = [](const std::vector &x) -> bool { + size_t len = x.size(); + std::vector cnt(len, 0); + for (size_t i = 0; i < len; ++i) { + if (x[i] >= static_cast(len) || x[i] < 0 || cnt[x[i]]) { + return false; + } + cnt[x[i]]++; + } + return true; + }; + PADDLE_ENFORCE_EQ(true, check_perm(tensor_indices), + platform::errors::PreconditionNotMet( + "tensor_indices must be a permutation from 0 to %lu", + tensor_indices.size())); // the return vector std::vector> res; @@ -294,9 +411,15 @@ std::vector> AssignGroupBySize( for (size_t i = 0; i < vars.size(); ++i) { const auto &var = vars[i]; - if (is_sparse_gradient[i]) { + + size_t tensor_real_index = i; + if (!tensor_indices.empty()) { + tensor_real_index = tensor_indices[i]; + } + + if (is_sparse_gradient[tensor_real_index]) { // we keep sparse var a single group - res.push_back({i}); + res.push_back({tensor_real_index}); continue; } @@ -313,7 +436,7 @@ std::vector> AssignGroupBySize( << " is not tensor or selected_rows, so skip it"; continue; } - group_info.first.push_back(i); + group_info.first.push_back(tensor_real_index); group_info.second += framework::SizeOfType(var_dtype) * var_size; if (group_limit_index.find(var_dtype_str) == group_limit_index.end()) { @@ -344,10 +467,12 @@ std::vector> AssignGroupBySize( platform::errors::PreconditionNotMet( "AssignGroupBySize construct empty group, please check.")); } - std::sort(res.begin(), res.end(), - [](const std::vector &x, const std::vector &y) { - return x.front() < y.front(); - }); + if (tensor_indices.empty()) { + std::sort(res.begin(), res.end(), + [](const std::vector &x, const std::vector &y) { + return x.front() < y.front(); + }); + } return res; } #endif diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h index 5e38f8abb18..3e65685d5c2 100644 --- a/paddle/fluid/imperative/reducer.h +++ b/paddle/fluid/imperative/reducer.h @@ -86,6 +86,8 @@ class Group { std::vector dense_tensors_; std::vector length_; + + int64_t all_length_{0}; // Global indices of participating variables in the group std::vector variable_indices_; @@ -97,53 +99,15 @@ class Group { framework::proto::VarType::Type dtype_; // context is used to select the stream for concat - void ConcatTensors(const platform::CUDADeviceContext& context) { - switch (dtype_) { - case framework::proto::VarType::FP16: - ConcatTensorsForAllReduce(context, dense_tensors_, - &dense_contents_); - break; - case framework::proto::VarType::FP32: - ConcatTensorsForAllReduce(context, dense_tensors_, - &dense_contents_); - break; - case framework::proto::VarType::FP64: - ConcatTensorsForAllReduce(context, dense_tensors_, - &dense_contents_); - break; - default: - PADDLE_THROW(platform::errors::Unimplemented( - "Data type (%s) is not supported when it concats tensors for " - "allreduce.", - framework::DataTypeToString(dtype_))); - } - } + void ConcatTensors(const platform::CUDADeviceContext& context); // context is used to select the stream for split - void SplitTensors(const platform::CUDADeviceContext& context) { - switch (dtype_) { - case framework::proto::VarType::FP16: - SplitTensorsForAllReduce(context, &dense_contents_, - &dense_tensors_); - break; - case framework::proto::VarType::FP32: - SplitTensorsForAllReduce(context, &dense_contents_, - &dense_tensors_); - break; - case framework::proto::VarType::FP64: - SplitTensorsForAllReduce(context, &dense_contents_, - &dense_tensors_); - break; - default: - PADDLE_THROW(platform::errors::Unimplemented( - "Data type (%s) is not supported when it splits tensors for " - "allreduce.", - framework::DataTypeToString(dtype_))); - } - } + void SplitTensors(const platform::CUDADeviceContext& context); + + friend std::ostream& operator<<(std::ostream&, const Group&); }; -struct VariableIndex { +struct VariableLocator { // record the index in groups_ size_t group_index; size_t inside_group_index; @@ -155,22 +119,21 @@ class Reducer { const std::vector>& vars, const std::vector>& group_indices, const std::vector& is_sparse_gradient, - std::shared_ptr parallel_ctx); + std::shared_ptr parallel_ctx, + const std::vector& group_size_limits); virtual ~Reducer() {} void InitializeGroups(const std::vector>& group_indices); - int64_t InitializeDenseGroups(const std::vector& variable_indices_, - Group* p_group); + void InitializeDenseGroups(const std::vector& variable_indices_, + Group* p_group); void PrepareForBackward(); - void AddDistHook(VariableWrapper* var_warpper, - const VariableIndex& var_index); + void AddDistHook(VariableWrapper* var_warpper, size_t var_index); - void MarkVariableReady(const VariableIndex& var_index, - VariableWrapper* var_warpper); + void MarkVariableReady(size_t var_index, VariableWrapper* var_warpper); void MarkGroupReady(size_t group_index); @@ -178,15 +141,21 @@ class Reducer { void ReleaseReducer(); + std::vector> RebuildGruops(); + + void CreateGroupEvents(int group_num); + // Reducer Singleton static std::shared_ptr SetInstance( const std::vector>& vars, const std::vector>& group_indices, const std::vector& is_sparse_gradient, - std::shared_ptr parallel_ctx) { + std::shared_ptr parallel_ctx, + const std::vector& group_size_limits) { if (NULL == s_instance_) { s_instance_.reset(new paddle::imperative::Reducer( - vars, group_indices, is_sparse_gradient, parallel_ctx)); + vars, group_indices, is_sparse_gradient, parallel_ctx, + group_size_limits)); } return s_instance_; } @@ -208,17 +177,26 @@ class Reducer { std::once_flag once_flag_; std::vector is_sparse_gradient_; std::shared_ptr parallel_ctx_; + std::vector variable_locators_; + // Following variables are to help sync stream std::vector> events_; std::shared_ptr comm_enent_; cudaStream_t compute_stream_; cudaStream_t comm_stream_; + + // Following variables are to help rebuild group + bool has_rebuilt_group_{false}; + std::vector> rebuild_vars_; + std::vector rebuild_var_indices_; + const std::vector group_size_limits_; }; std::vector> AssignGroupBySize( const std::vector>& tensors, const std::vector& is_sparse_gradient, - const std::vector& group_size_limits); + const std::vector& group_size_limits, + const std::vector& tensor_indices = {}); #endif } // namespace imperative diff --git a/paddle/fluid/imperative/tests/CMakeLists.txt b/paddle/fluid/imperative/tests/CMakeLists.txt index 782f6dad58d..b236ece541e 100644 --- a/paddle/fluid/imperative/tests/CMakeLists.txt +++ b/paddle/fluid/imperative/tests/CMakeLists.txt @@ -12,3 +12,7 @@ cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place) cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy) cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy) + +if (WITH_NCCL) +cc_test(test_group SRCS test_group.cc DEPS reducer concat_and_split memcpy) +endif() diff --git a/paddle/fluid/imperative/tests/test_group.cc b/paddle/fluid/imperative/tests/test_group.cc new file mode 100644 index 00000000000..2e967d296d8 --- /dev/null +++ b/paddle/fluid/imperative/tests/test_group.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "glog/logging.h" +#include "gtest/gtest.h" + +#if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/imperative/reducer.h" +#endif + +namespace paddle { +namespace imperative { + +#if defined(PADDLE_WITH_NCCL) +TEST(TestGroup, TestPrintGroupMessage) { + Group group; + std::stringstream stream1, stream2; + stream1 << group; + ASSERT_STREQ(stream1.str().c_str(), + "numul: 0 ;is_sparse: 0 ;var number: 0\n[]\n"); + + std::vector vars; + size_t vars_num = 102; + for (size_t i = 0; i < vars_num; ++i) { + vars.push_back(i); + } + group.variable_indices_ = vars; + group.all_length_ = 102; + group.is_sparse_ = false; + + std::string head = "numul: 102 ;is_sparse: 0 ;var number: 102\n"; + head = head + "["; + auto begin = vars.begin(); + auto end = vars.end(); + for (int i = 0; begin != end && i < 100; ++i, ++begin) { + if (i > 0) head += ' '; + head += std::to_string(*begin); + } + if (begin != end) { + head += " ..."; + } + head += "]\n"; + stream2 << group; + ASSERT_STREQ(stream2.str().c_str(), head.c_str()); +} + +#endif + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 4a4f55cf57b..7a48ffa82a4 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -1289,9 +1289,11 @@ void BindImperative(py::module *m_ptr) { [](const std::vector> &vars, const std::vector> &group_indices, const std::vector &is_sparse_gradient, - std::shared_ptr parallel_ctx) { + std::shared_ptr parallel_ctx, + const std::vector &group_size_limits) { return imperative::Reducer::SetInstance( - vars, group_indices, is_sparse_gradient, parallel_ctx); + vars, group_indices, is_sparse_gradient, parallel_ctx, + group_size_limits); })) .def("prepare_for_backward", &imperative::Reducer::PrepareForBackward, py::call_guard()); @@ -1299,6 +1301,7 @@ void BindImperative(py::module *m_ptr) { m.def("assign_group_by_size", &imperative::AssignGroupBySize, py::arg("vars"), py::arg("is_sparse_gradient"), py::arg("group_size_limits") = std::vector{25 * 1024 * 1024}, + py::arg("tensor_indices") = std::vector{}, py::call_guard()); #endif } diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 98b6bc0cc89..658143d0a22 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -18,7 +18,6 @@ from paddle.fluid.framework import Variable, set_flags, core from paddle.fluid.wrapped_decorator import wrap_decorator import google.protobuf.text_format import google.protobuf -from paddle.fluid.framework import dygraph_only __all__ = ["DistributedStrategy"] diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 77a0308a533..731a9f809d8 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -441,10 +441,11 @@ class DataParallel(layers.Layer): "ParallelContext must be initialized before. You should use init_parallel_env() before" \ "constructing the DataParallel." - self._reducer = core.Reducer(trainable_parameters, - list(reversed(self.group_indices)), - is_sparse_gradient, - parallel_helper.__parallel_ctx__clz__) + self._reducer = core.Reducer( + trainable_parameters, + list(reversed(self.group_indices)), is_sparse_gradient, + parallel_helper.__parallel_ctx__clz__, + [self.last_comm_buffer_size, self.comm_buffer_size]) def forward(self, *inputs, **kwargs): if self._strategy.nranks > 1: diff --git a/python/paddle/fluid/tests/unittests/test_imperative_group.py b/python/paddle/fluid/tests/unittests/test_imperative_group.py index 299efa6d9c1..f9635809651 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_group.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_group.py @@ -155,6 +155,30 @@ class TestDataParallelGroup(unittest.TestCase): var_list, [True, False, False, False, False, True], [200, 400]) self.assertEqual([[0], [1], [2], [3], [4], [5]], res) + def test_construct_group8(self): + # one dtype & one limit capability & have tensor_indices + var_list = [] + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25])) + var_list.append( + self.create_varbase(core.VarDesc.VarType.FP32, [2, 100])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25])) + res = core.assign_group_by_size(var_list, [False, False, False, False], + [400], [3, 0, 1, 2]) + self.assertEqual([[3, 0], [1], [2]], res) + + def test_construct_group9(self): + # one dtype & one limit capability & have tensor_indices + var_list = [] + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25])) + var_list.append( + self.create_varbase(core.VarDesc.VarType.FP32, [2, 1000])) + res = core.assign_group_by_size(var_list, [False, False, False, True], + [300], [1, 0, 2, 3]) + self.assertEqual([[1, 0], [3], [2]], res) + if __name__ == '__main__': unittest.main() -- GitLab