From 1a32391c66484a3466bc1a5595e97816097a60f5 Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Tue, 15 Mar 2022 10:49:55 +0800 Subject: [PATCH] [Dygraph] Refactoring of reducer in DataParallel (#40389) * refactor reducer * modify cmakelists * solve conflicts * rename group and update process_group * fix bugs of ProcessGroupNCCL * modify for CIs * refactoring reducer --- .../distributed/collective/CMakeLists.txt | 3 +- .../collective/ProcessGroupNCCL.cc | 4 +- .../fluid/distributed/collective/reducer.cc | 426 +++++++++++++++++- paddle/fluid/distributed/collective/reducer.h | 99 +++- paddle/fluid/pybind/distributed_py.cc | 23 + python/paddle/fluid/dygraph/parallel.py | 63 ++- ...llel_dygraph_dataparallel_in_eager_mode.py | 127 ++++++ .../test_parallel_dygraph_dataparallel.py | 5 + python/paddle/optimizer/optimizer.py | 9 +- 9 files changed, 736 insertions(+), 23 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/parallel_dygraph_dataparallel_in_eager_mode.py diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt index f88c993d85..3fca45cc06 100644 --- a/paddle/fluid/distributed/collective/CMakeLists.txt +++ b/paddle/fluid/distributed/collective/CMakeLists.txt @@ -1,8 +1,9 @@ cc_library(processgroup SRCS ProcessGroup.cc DEPS phi phi_api eager_api) +cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup phi phi_api) + if (WITH_DISTRIBUTE) cc_library(processgroup_gloo SRCS ProcessGroupGloo.cc DEPS phi phi_api eager_api gloo_wrapper) endif() -cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup) if(WITH_NCCL) cc_library(processgroup_nccl SRCS ProcessGroupNCCL.cc DEPS place cuda_stream enforce collective_helper device_context phi phi_api eager_api) diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index 67715f410d..7f21bcee87 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -88,8 +88,8 @@ void SyncDefaultStream( for (size_t i = 0; i < places.size(); ++i) { auto* default_ctx = static_cast( platform::DeviceContextPool::Instance().Get(places[i])); - ncclEvents[i].Record(*dev_ctx[i]); - ncclEvents[i].Block(*default_ctx); + ncclEvents[i].Record(*default_ctx); + ncclEvents[i].Block(*dev_ctx[i]); } } diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc index 59f3ea3b0a..5533f3f4cb 100644 --- a/paddle/fluid/distributed/collective/reducer.cc +++ b/paddle/fluid/distributed/collective/reducer.cc @@ -13,7 +13,6 @@ // limitations under the License. 
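Before the reducer.cc changes below, here is a minimal, self-contained sketch of the fused-allreduce flow the new EagerReducer implements: concatenate a group's gradients into one flat buffer, scale by 1/nranks, allreduce with SUM, then split back. It is illustration only; numpy stands in for device tensors and `fake_allreduce_sum` / `fused_allreduce` are hypothetical helpers, not Paddle API.

```python
# Standalone illustration of the fused-allreduce timeline used by EagerGroup /
# EagerReducer below: concat -> div_nranks -> allreduce(SUM) -> split.
# numpy replaces real device tensors; fake_allreduce_sum replaces the NCCL
# collective. All names here are hypothetical, for explanation only.
import numpy as np


def fake_allreduce_sum(local_buffers):
    """Pretend collective: every rank ends up with the element-wise sum."""
    total = np.sum(local_buffers, axis=0)
    return [total.copy() for _ in local_buffers]


def fused_allreduce(per_rank_grads):
    """per_rank_grads: list (one entry per rank) of lists of gradient arrays."""
    nranks = len(per_rank_grads)
    shapes = [g.shape for g in per_rank_grads[0]]

    # 1) concat: flatten each rank's gradients into one contiguous buffer
    flat = [np.concatenate([g.ravel() for g in grads]) for grads in per_rank_grads]

    # 2) div_nranks: pre-scale so a SUM allreduce yields the mean gradient
    flat = [buf / nranks for buf in flat]

    # 3) allreduce(SUM)
    flat = fake_allreduce_sum(flat)

    # 4) split: carve the fused buffer back into the original shapes
    out = []
    for buf in flat:
        grads, offset = [], 0
        for shape in shapes:
            n = int(np.prod(shape))
            grads.append(buf[offset:offset + n].reshape(shape))
            offset += n
        out.append(grads)
    return out


if __name__ == "__main__":
    rank0 = [np.ones((2, 3)), np.full((4,), 2.0)]
    rank1 = [np.zeros((2, 3)), np.full((4,), 4.0)]
    averaged = fused_allreduce([rank0, rank1])
    print(averaged[0][0])  # every rank now holds the averaged gradients
    print(averaged[0][1])
```

Fusing many small gradients into one buffer is what amortizes the per-collective launch cost; the scale-before-allreduce step is why the C++ code below multiplies `dense_contents_` by `1.0 / nranks_` before calling `AllReduce`.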
#include "paddle/fluid/distributed/collective/reducer.h" -#include "paddle/phi/common/data_type.h" namespace paddle { namespace distributed { @@ -127,5 +126,430 @@ std::vector> Eager_AssignGroupBySize( return res; } +template +static void ConcatTensorsForAllReduce( + const DeviceContext &context, + const std::vector &dense_tensors_, + Tensor *p_dense_contents) { + operators::math::ConcatFunctor concat_functor_; + concat_functor_( + context, dense_tensors_, 0, + std::dynamic_pointer_cast(p_dense_contents->impl()) + .get()); +} + +template +static void SplitTensorsForAllReduce( + const DeviceContext &context, Tensor *p_dense_contents, + std::vector *p_dense_tensors) { + auto *in = + std::dynamic_pointer_cast(p_dense_contents->impl()) + .get(); + std::vector outs; + std::vector shape_refer; + + outs.reserve(p_dense_tensors->size()); + shape_refer.reserve(p_dense_tensors->size()); + + for (auto &tensor : *p_dense_tensors) { + outs.emplace_back(&tensor); + shape_refer.emplace_back(&tensor); + } + + operators::math::SplitFunctor split_functor_; + split_functor_(context, *in, shape_refer, 0, &outs); +} + +// context is used to select the stream for concat +template +static void ConcatTensorsWithType( + const DeviceContext &context, + const std::vector &dense_tensors_, + Tensor *p_dense_contents, phi::DataType type) { + switch (type) { + case phi::DataType::FLOAT16: + ConcatTensorsForAllReduce( + context, dense_tensors_, p_dense_contents); + break; + case phi::DataType::FLOAT32: + ConcatTensorsForAllReduce(context, dense_tensors_, + p_dense_contents); + break; + case phi::DataType::FLOAT64: + ConcatTensorsForAllReduce(context, dense_tensors_, + p_dense_contents); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when it concats tensors for " + "allreduce.", + type)); + } +} + +// context is used to select the stream for split +template +static void SplitTensorsWithType(const DeviceContext &context, + Tensor *p_dense_contents, + std::vector *p_dense_tensors, + phi::DataType type) { + switch (type) { + case phi::DataType::FLOAT16: + SplitTensorsForAllReduce( + context, p_dense_contents, p_dense_tensors); + break; + case phi::DataType::FLOAT32: + SplitTensorsForAllReduce(context, p_dense_contents, + p_dense_tensors); + break; + case phi::DataType::FLOAT64: + SplitTensorsForAllReduce(context, p_dense_contents, + p_dense_tensors); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when it splits tensors for " + "allreduce.", + type)); + } +} + +void EagerGroup::ConcatTensors(const platform::Place &place) { + if (platform::is_gpu_place(place)) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto *default_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + ConcatTensorsWithType(*default_ctx, dense_tensors_, &dense_contents_, + dtype_); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't concat grad tensors since it's not compiled with NCCL," + "Please recompile or reinstall Paddle with NCCL support.")); +#endif + } else if (platform::is_cpu_place(place)) { + auto *default_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + ConcatTensorsWithType(*default_ctx, dense_tensors_, &dense_contents_, + dtype_); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Concat grad tensor not supported on place (%s)", place)); + } +} + +void EagerGroup::SplitTensors(const platform::Place &place) { + if 
(platform::is_gpu_place(place)) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto *default_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + SplitTensorsWithType(*default_ctx, &dense_contents_, &dense_tensors_, + dtype_); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't split grad tensor since it's not compiled with NCCL," + "Please recompile or reinstall Paddle with NCCL support.")); +#endif + } else if (platform::is_cpu_place(place)) { + auto *default_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + SplitTensorsWithType(*default_ctx, &dense_contents_, &dense_tensors_, + dtype_); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Split grad tensor not supported on place (%s)", place)); + } +} + +EagerReducer::EagerReducer( + const std::vector tensors, + const std::vector> &group_indices, + const std::vector &is_sparse_gradient, + std::shared_ptr process_group, + const std::vector &group_size_limits, bool find_unused_parameters) + : tensors_(tensors), + group_indices_(group_indices), + is_sparse_gradient_(is_sparse_gradient), + process_group_(process_group), + group_size_limits_(group_size_limits), + find_unused_vars_each_step_(find_unused_parameters) { + VLOG(3) << "Start construct the Reducer ..."; + + nranks_ = process_group_->GetSize(); + + // initialize groups + InitializeGroups(group_indices); + + for (size_t global_var_index = 0; global_var_index < tensors_.size(); + ++global_var_index) { + auto tensor = tensors_[global_var_index]; + auto reduce_hook = [=](void) -> void { + this->AddDistHook(global_var_index); + }; + + const auto &grad_node = GetGradNodeFromTensor(&tensor); + + PADDLE_ENFORCE( + grad_node.get() != nullptr, + paddle::platform::errors::Fatal("Detected NULL grad_node," + "Leaf tensor should have had grad_node " + "with type: GradNodeAccumulation")); + const auto &accumulation_grad_node = + std::dynamic_pointer_cast(grad_node); + accumulation_grad_node->RegisterReduceHook( + std::make_shared(reduce_hook)); + } + + vars_marked_ready_.resize(tensors_.size(), false); + local_used_vars_.resize(tensors_.size(), 0); +} + +std::shared_ptr EagerReducer::GetGradNodeFromTensor( + Tensor *tensor) { + auto *autograd_meta = tensor->get_autograd_meta(); + const auto &grad_node = + static_cast(autograd_meta)->GetMutableGradNode(); + return grad_node; +} + +void EagerReducer::InitializeGroups( + const std::vector> &group_indices) { + VLOG(3) << "Start initialize groups .."; + + // clear the group + groups_.clear(); + groups_.reserve(group_indices.size()); + + variable_locators_.clear(); + variable_locators_.resize(tensors_.size()); + + auto group_nums = group_indices.size(); + for (size_t group_index = 0; group_index < group_nums; ++group_index) { + const auto &tensor_indices_ = group_indices[group_index]; + PADDLE_ENFORCE_GT( + tensor_indices_.size(), 0, + platform::errors::PreconditionNotMet( + "The number of group[%d]'s elements is 0.", group_index)); + + EagerGroup group; + + // It's just for check the sparse or dense + auto first_var = tensors_[tensor_indices_.front()]; + if (tensor_indices_.size() == 1 && + is_sparse_gradient_[tensor_indices_.front()]) { + // process the sparse gradient. one sparse, one group + group.dtype_ = first_var.dtype(); + } else { + // process the dense gradient. 
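+      // Dense case: all member gradients of this group share one fused buffer.
+      // InitializeDenseGroups records each member's length and original shape,
+      // and dense_contents_ is then allocated with numel == all_length_ on the
+      // backend derived from inner_place_ (GPU or CPU).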
+ InitializeDenseGroups(tensor_indices_, &group); + experimental::Backend backend; + switch (inner_place_.GetType()) { + case phi::AllocationType::GPU: + backend = experimental::Backend::GPU; + break; + case phi::AllocationType::CPU: + backend = experimental::Backend::CPU; + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Place type (%s) is not supported. ", inner_place_)); + break; + } + group.dense_contents_ = paddle::experimental::empty( + ScalarArray({group.all_length_}), group.dtype_, backend); + } + + // map tensors to this group by VariableLocator + size_t inside_group_index = 0; + for (const auto var_index : tensor_indices_) { + TensorLocator tensor_locator; + tensor_locator.group_index = group_index; + tensor_locator.inside_group_index = inside_group_index++; + variable_locators_[var_index] = tensor_locator; + } + group.tensor_indices_ = std::move(tensor_indices_); + groups_.emplace_back(std::move(group)); + + VLOG(3) << "The Group[" << group_index << "]:" << groups_.back(); + } +} + +void EagerReducer::InitializeDenseGroups( + const std::vector &tensor_indices_, EagerGroup *p_group) { + VLOG(3) << "InitializeDenseGroups."; + int64_t all_length = 0; + for (size_t index = 0; index < tensor_indices_.size(); ++index) { + auto tensor_index = tensor_indices_[index]; + auto &tensor = tensors_[tensor_index]; + auto &tensor_name = tensor.name(); + + PADDLE_ENFORCE_EQ(tensor.is_initialized(), true, + platform::errors::PreconditionNotMet( + "Tensor %s is not initialized.", tensor_name)); + const auto size = tensor.numel(); + PADDLE_ENFORCE_GT( + size, 0, platform::errors::PreconditionNotMet( + "The number of tensor %s's elements is 0.", tensor_name)); + all_length += size; + + p_group->length_.push_back(size); + + // for concat operator + p_group->origin_shapes_.push_back(ScalarArray(tensor.shape())); + p_group->dense_tensors_.push_back(phi::DenseTensor()); + + const auto &dtype = tensor.dtype(); + const auto &place = tensor.place(); + const auto &inner_place = tensor.impl()->place(); + if (index > 0) { + PADDLE_ENFORCE_EQ(dtype, p_group->dtype_, + platform::errors::PreconditionNotMet( + "Tensor %s has unexpected dtype.", tensor_name)); + PADDLE_ENFORCE_EQ(place, place_, + platform::errors::PreconditionNotMet( + "Tensor %s has different place. Expected place is " + "%s, but actual place is %s", + tensor_name, inner_place_, inner_place)); + } else { + p_group->dtype_ = dtype; + place_ = place; + inner_place_ = inner_place; + } + } + p_group->all_length_ = all_length; +} + +void EagerReducer::PrepareForBackward(const std::vector &outputs) { + VLOG(3) << "after forward, then reset count for backward."; + grad_need_hooks_ = true; + next_group_ = 0; + std::for_each(groups_.begin(), groups_.end(), [](EagerGroup &group) { + group.pending_ = group.tensor_indices_.size(); + }); + + // reinitialize vars_marked_ready_ for next iteration + vars_marked_ready_.clear(); + vars_marked_ready_.resize(tensors_.size(), false); +} + +void EagerReducer::AddDistHook(size_t var_index) { + PADDLE_ENFORCE_LT(var_index, variable_locators_.size(), + platform::errors::OutOfRange( + "Out of bounds variable index. it must be less" + "than %d, but it is %d", + variable_locators_.size(), var_index)); + + // gradient synchronization is not required when grad_need_hooks_ is false. 
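+  // grad_need_hooks_ is only set to true by PrepareForBackward(), i.e. after a
+  // DataParallel forward pass, so hooks firing outside such a round are no-ops.
+  // Otherwise this hook marks the tensor's slot in its group as ready, and the
+  // group is flushed (fused allreduce) once its pending counter reaches zero.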
+ if (!grad_need_hooks_) { + return; + } + + auto &tensor = tensors_[var_index]; + const auto &grad_node = GetGradNodeFromTensor(&tensor); + + VLOG(3) << "Var[" << var_index << "] [" << (*grad_node).name() + << "] arrived and triggered disthook"; + + local_used_vars_[var_index] = 1; + + MarkVarReady(var_index, true); +} + +void EagerReducer::MarkVarReady(const size_t var_index, + const bool is_used_var) { + const auto &var_locator = variable_locators_[var_index]; + const auto group_index = var_locator.group_index; + const auto inside_group_index = var_locator.inside_group_index; + + auto &group = groups_[group_index]; + auto &group_tensor = group.dense_tensors_[inside_group_index]; + auto *autograd_meta = tensors_[var_index].get_autograd_meta(); + auto &grad_tensor = static_cast(autograd_meta)->Grad(); + + group_tensor + .ShareDataWith( + *(std::dynamic_pointer_cast(grad_tensor.impl()))) + .Resize({grad_tensor.numel()}); + + vars_marked_ready_[var_index] = true; + + if (--group.pending_ == 0) { + // can start allreduce + MarkGroupReady(group_index); + } +} + +void EagerReducer::MarkGroupReady(size_t group_index) { + VLOG(3) << "Group[" << group_index << "] is ready"; + + PADDLE_ENFORCE_GE( + group_index, next_group_, + platform::errors::PreconditionNotMet( + "The index of the incoming group must be greater " + "than or equal to the previously synchronized group index, " + "expect it to greater than or equal to %d, but got %d.", + next_group_, group_index)); + + if (group_index > next_group_) { + VLOG(3) << "It will adjust the order of group in next batch automatically"; + return; + } + + for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0; + ++next_group_) { + UNUSED auto &group = groups_[next_group_]; + FusedAllReduceSchedule(&group, next_group_); + } +} + +void EagerReducer::FusedAllReduceSchedule(EagerGroup *group, + const int curr_group_index) { + // The overall timeline: concat > div_nranks > allreduce > split + distributed::AllreduceOptions opts; + opts.reduce_op = ReduceOp::SUM; + + VLOG(3) << "group [" << curr_group_index << "] start fused_allreduce."; + + // concat tensors + group->ConcatTensors(inner_place_); + + // div nranks + double scaling = 1.0 / nranks_; + paddle::experimental::scale_(group->dense_contents_, scaling, 0.0, false); + + // all_reduce + std::vector reduce_tensors = {group->dense_contents_}; + tasks_.push_back(process_group_->AllReduce(reduce_tensors, opts)); + + if (tasks_.size() == groups_.size()) { + for (size_t index = 0; index < tasks_.size(); index++) { + auto &task = tasks_.back(); + task->Synchronize(); + tasks_.pop_back(); + } + for (size_t index = 0; index < groups_.size(); index++) { + auto &group = groups_[index]; + group.SplitTensors(inner_place_); + } + } +} + +std::ostream &operator<<(std::ostream &out, const EagerGroup &group) { + const auto &tensors_ = group.tensor_indices_; + out << "numel: " << group.all_length_ << " ;var number: " << tensors_.size() + << "\n"; + auto begin = tensors_.begin(); + auto end = tensors_.end(); + out << "["; + for (int i = 0; begin != end && i < 100; ++i, ++begin) { + if (i > 0) out << ' '; + out << *begin; + } + if (begin != end) { + out << " ..."; + } + out << "]\n"; + return out; +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/reducer.h b/paddle/fluid/distributed/collective/reducer.h index f8c75385ef..ac6f3fbe59 100644 --- a/paddle/fluid/distributed/collective/reducer.h +++ b/paddle/fluid/distributed/collective/reducer.h @@ -17,16 
+17,109 @@ #include #include #include "paddle/fluid/distributed/collective/ProcessGroup.h" +#include "paddle/fluid/eager/accumulation/accumulation_node.h" +#include "paddle/fluid/eager/api/utils/hook_utils.h" #include "paddle/fluid/eager/api/utils/tensor_utils.h" +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/operators/math/concat_and_split.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/phi/api/include/api.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/api/lib/ext_compat_utils.h" +#include "paddle/phi/common/data_type.h" namespace paddle { namespace distributed { using Tensor = paddle::experimental::Tensor; +using Scalar = paddle::experimental::ScalarBase; +using ScalarArray = + paddle::experimental::ScalarArrayBase; std::vector> Eager_AssignGroupBySize( - const std::vector, const std::vector& is_sparse_gradient, - const std::vector& group_size_limits, - const std::vector& tensor_indices = {}); + const std::vector, const std::vector &is_sparse_gradient, + const std::vector &group_size_limits, + const std::vector &tensor_indices = {}); + +class EagerGroup { + public: + Tensor dense_contents_; + + // for concat kernel + std::vector dense_tensors_; + std::vector length_; + int64_t all_length_{0}; + std::vector origin_shapes_; + + // Global indices of participating tensors in the group + std::vector tensor_indices_; + + // Number of params that haven't been ready. When it is 0, it means + // the group is ready. + size_t pending_ = -1; + + // external message of group + phi::DataType dtype_; + + // context is used to select the stream for concat + void ConcatTensors(const platform::Place &); + + // context is used to select the stream for split + void SplitTensors(const platform::Place &); + + friend std::ostream &operator<<(std::ostream &, const EagerGroup &); +}; + +struct TensorLocator { + // record the index in groups_ + size_t group_index; + size_t inside_group_index; +}; + +class EagerReducer { + public: + explicit EagerReducer( + const std::vector tensors, + const std::vector> &group_indices, + const std::vector &is_sparse_gradient, + std::shared_ptr process_group, + const std::vector &group_size_limits, + bool find_unused_parameters); + + virtual ~EagerReducer() {} + + std::shared_ptr GetGradNodeFromTensor(Tensor *tensor); + + void InitializeGroups(const std::vector> &group_indices); + void InitializeDenseGroups(const std::vector &tensor_indices_, + EagerGroup *p_group); + void PrepareForBackward(const std::vector &outputs); + void AddDistHook(size_t var_index); + void MarkVarReady(const size_t var_index, const bool is_used_var); + void MarkGroupReady(const size_t group_index); + void FusedAllReduceSchedule(EagerGroup *group, const int curr_group_index); + + private: + std::vector tensors_; + std::vector> group_indices_; + std::vector is_sparse_gradient_; + std::shared_ptr process_group_; + std::vector group_size_limits_; + bool find_unused_vars_each_step_; + + std::vector groups_; + std::vector variable_locators_; + PlaceType place_; + platform::Place inner_place_; + size_t next_group_ = 0; + int64_t nranks_ = -1; + std::vector> tasks_; + + bool grad_need_hooks_{false}; + + std::vector vars_marked_ready_; + std::vector local_used_vars_; +}; } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 0b17967038..1df917b8c3 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ 
b/paddle/fluid/pybind/distributed_py.cc @@ -51,6 +51,18 @@ namespace pybind { using Tensor = paddle::experimental::Tensor; +std::shared_ptr CreateEagerReducer( + py::handle py_tensors, + const std::vector> &group_indices, + const std::vector &is_sparse_gradient, + std::shared_ptr process_group, + const std::vector &group_size_limits, bool find_unused_parameters) { + auto params = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0); + return std::make_shared( + params, group_indices, is_sparse_gradient, process_group, + group_size_limits, find_unused_parameters); +} + #if defined(PADDLE_WITH_GLOO) using ProcessGroupGloo = paddle::distributed::ProcessGroupGloo; using GlooStore = paddle::distributed::ProcessGroupGloo::GlooStore; @@ -271,6 +283,17 @@ void BindDistributed(py::module *m) { py::arg("group_size_limits") = std::vector{25 * 1024 * 1024}, py::arg("tensor_indices") = std::vector{}, py::call_guard()); + + py::class_>(*m, "EagerReducer", + R"DOC()DOC") + .def(py::init(&CreateEagerReducer)) + .def("prepare_for_backward", + [](distributed::EagerReducer &self, py::handle py_tensors) { + auto params = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0); + self.PrepareForBackward(params); + }, + py::arg("tensors"), py::call_guard()); } } // end namespace pybind diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 652916491e..86d76f1b20 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -30,7 +30,7 @@ from paddle.fluid.dygraph import to_variable, no_grad from paddle.utils import deprecated from ..layers import collective from paddle.fluid.dygraph import base as imperative_base -from paddle.fluid.framework import ParamBase +from paddle.fluid.framework import ParamBase, _in_eager_mode __all__ = ["prepare_context", "ParallelEnv", "DataParallel"] @@ -397,6 +397,16 @@ def sync_params_buffers(model, 'axis': 0}) +@imperative_base.no_grad +@framework.dygraph_only +def sync_eager_params(model, comm_group=None, src_rank=0): + for _, param in model._obtain_parameters_buffers().items(): + if not isinstance(param, core.eager.Tensor): + raise TypeError("The data type of '%s' must be '%s'" % + (param.name, core.eager.Tensor)) + comm_group.broadcast(param, src_rank).synchronize() + + class DataParallel(layers.Layer): """ Run the dygraph module with data parallelism. @@ -576,6 +586,7 @@ class DataParallel(layers.Layer): self.process_group = process_group self.gradient_as_buffer_view = gradient_as_buffer_view self.static_graph = static_graph + self.var_dtype = core.eager.Tensor if _in_eager_mode() else core.VarBase # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy. # It just stores some environment variables, which can be constructed by @@ -592,11 +603,20 @@ class DataParallel(layers.Layer): "ParallelContext must be initialized before. You should use init_parallel_env() before" \ "constructing the DataParallel." + if self.process_group is None and _in_eager_mode(): + raise RuntimeError( + "Process group should be built in DataParallel of eager mode." + ) + # sync buffer and params # TODO(liuyuhui) Currently not support xpu. 
xpu is # still broadcasting parameters when calling layer if not paddle.is_compiled_with_xpu(): - sync_params_buffers(self._layers) + if _in_eager_mode(): + sync_eager_params( + self._layers, comm_group=self.process_group) + else: + sync_params_buffers(self._layers) self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024) # NOTE(shenliang03): We can set environment variables to control @@ -620,9 +640,9 @@ class DataParallel(layers.Layer): if param is None or param in params_set: continue params_set.add(param) - if not isinstance(param, core.VarBase): - raise TypeError("The data type of '%s' must be Varbase" % - param.name) + if not isinstance(param, self.var_dtype): + raise TypeError("The data type of '%s' must be '%s'" % + (param.name, self.var_dtype)) if param.trainable: layers_param.append((sublayer, param)) @@ -649,19 +669,32 @@ class DataParallel(layers.Layer): check_layer_sparse(sublayer) for sublayer, _ in layers_param ] - self.group_indices = core.assign_group_by_size( - trainable_parameters, is_sparse_gradient, - [self.last_comm_buffer_size, self.comm_buffer_size]) + if _in_eager_mode(): + self.group_indices = core.eager_assign_group_by_size( + trainable_parameters, is_sparse_gradient, + [self.last_comm_buffer_size, self.comm_buffer_size]) + + self._reducer = core.EagerReducer( + trainable_parameters, + list(reversed(self.group_indices)), is_sparse_gradient, + self.process_group, + [self.last_comm_buffer_size, self.comm_buffer_size], + self.find_unused_parameters) + else: + self.group_indices = core.assign_group_by_size( + trainable_parameters, is_sparse_gradient, + [self.last_comm_buffer_size, self.comm_buffer_size]) - self._reducer = core.Reducer( - trainable_parameters, - list(reversed(self.group_indices)), is_sparse_gradient, - parallel_helper.__parallel_ctx__clz__, - [self.last_comm_buffer_size, self.comm_buffer_size], - self.find_unused_parameters) + self._reducer = core.Reducer( + trainable_parameters, + list(reversed(self.group_indices)), is_sparse_gradient, + parallel_helper.__parallel_ctx__clz__, + [self.last_comm_buffer_size, self.comm_buffer_size], + self.find_unused_parameters) def _find_varbase(self, obj): - if isinstance(obj, core.VarBase): + var_type = core.eager.Tensor if _in_eager_mode() else core.VarBase + if isinstance(obj, var_type): return [obj] if isinstance(obj, (list, tuple)): return itertools.chain(*map(self._find_varbase, obj)) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_dataparallel_in_eager_mode.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_dataparallel_in_eager_mode.py new file mode 100644 index 0000000000..8ff68a1ce0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_dataparallel_in_eager_mode.py @@ -0,0 +1,127 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
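The `group_indices` computed in parallel.py above come from `core.eager_assign_group_by_size`. As a rough sketch of that bucketing idea (a simplification, not the real implementation, which also separates tensors by dtype and honors an explicit `tensor_indices` order): each sparse gradient gets its own group, and dense gradients are packed greedily until the current byte limit is exceeded, after which the next limit in `group_size_limits` is used. The function name below is hypothetical.

```python
# Simplified, standalone sketch of size-based gradient bucketing (illustration
# only; core.eager_assign_group_by_size also splits groups by dtype and
# supports a caller-provided ordering).
def assign_group_by_size_sketch(numels, elem_size, is_sparse, size_limits):
    groups, current, current_bytes = [], [], 0
    limit_idx = 0
    limit = size_limits[min(limit_idx, len(size_limits) - 1)]
    for index, numel in enumerate(numels):
        if is_sparse[index]:
            groups.append([index])       # one sparse gradient -> its own group
            continue
        current.append(index)
        current_bytes += numel * elem_size
        if current_bytes >= limit:       # close the bucket once the limit is hit
            groups.append(current)
            current, current_bytes = [], 0
            limit_idx += 1
            limit = size_limits[min(limit_idx, len(size_limits) - 1)]
    if current:
        groups.append(current)
    return groups


if __name__ == "__main__":
    # Three fp32 parameters of 1500/300/100 elements, 1 KB then 4 KB limits.
    print(assign_group_by_size_sketch(
        numels=[1500, 300, 100], elem_size=4,
        is_sparse=[False, False, False], size_limits=[1024, 4096]))
    # -> [[0], [1, 2]]
```

This is also why parallel.py passes `[last_comm_buffer_size, comm_buffer_size]`: the first, smaller limit closes an early bucket quickly so communication can overlap with the remaining backward computation.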
+ +from __future__ import division +from __future__ import print_function + +import unittest +import os +import numpy as np +import random + +import paddle +import paddle.nn as nn +from paddle.fluid.dygraph.nn import Linear +import paddle.fluid.core as core +from paddle.fluid.framework import _test_eager_guard +import paddle.distributed as dist +from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.optimizer import SGD +from paddle.fluid.initializer import NumpyArrayInitializer + + +def init_process_group(strategy=None): + nranks = ParallelEnv().nranks + rank = ParallelEnv().local_rank + is_master = True if rank == 0 else False + store = paddle.fluid.core.TCPStore("127.0.0.1", 6172, is_master, nranks) + group = core.ProcessGroupNCCL(store, rank, nranks) + return group + + +class LinearModel(nn.Layer): + def __init__(self, attr_list): + super(LinearModel, self).__init__() + self._linear1 = paddle.nn.Linear( + 50, 30, weight_attr=attr_list[0], bias_attr=False) + self._linear2 = paddle.nn.Linear( + 30, 10, weight_attr=attr_list[1], bias_attr=False) + self._linear3 = paddle.nn.Linear( + 10, 10, weight_attr=attr_list[2], bias_attr=False) + + def forward(self, x): + output = self._linear1(x) + output = self._linear2(output) + output = self._linear3(output) + return output + + +class TestDistTraning(unittest.TestCase): + def test_multiple_gpus(self): + process_group = init_process_group() + self.generate_reducer("float32", process_group) + self.generate_reducer("float16", process_group) + + def generate_reducer(self, dtype, process_group): + dev_id = ParallelEnv().dev_id + np.random.seed(2022 + dev_id) + paddle.set_default_dtype(dtype) + + w_1 = paddle.ParamAttr(initializer=NumpyArrayInitializer( + np.random.rand(50, 30).astype(dtype))) + w_2 = paddle.ParamAttr(initializer=NumpyArrayInitializer( + np.random.rand(30, 10).astype(dtype))) + w_3 = paddle.ParamAttr(initializer=NumpyArrayInitializer( + np.random.rand(10, 10).astype(dtype))) + + attr_list = [w_1, w_2, w_3] + inp = np.random.rand(10, 50).astype(dtype) + + # original reducer + params_a = self.model_train(attr_list, inp) + + # refactored reducer in eager mode + with _test_eager_guard(): + params_b = self.model_train( + attr_list, inp, process_group=process_group) + + for i in range(len(params_a)): + np.testing.assert_allclose(params_a[i].numpy(), params_b[i].numpy()) + + def model_train(self, attr_list, inp, process_group=None): + model = LinearModel(attr_list) + model = paddle.DataParallel(model, process_group=process_group) + optimizer = SGD(learning_rate=0.0003, parameters=model.parameters()) + + x = paddle.to_tensor(inp) + x.stop_gradient = False + + for step in range(10): + y = model(x) + loss = y.mean() + + loss.backward() + optimizer.step() + optimizer.clear_grad() + + return model.parameters() + + +class TestCatchErrors1(unittest.TestCase): + def test_multiple_gpus(self): + linear = paddle.nn.Linear(2, 4) + with _test_eager_guard(): + self.assertRaises(RuntimeError, paddle.DataParallel, linear) + + +class TestCatchErrors2(unittest.TestCase): + def test_multiple_gpus(self): + with _test_eager_guard(): + linear = paddle.nn.Linear(2, 4) + self.assertRaises(RuntimeError, paddle.DataParallel, linear) + + +if __name__ == '__main__': + dist.init_parallel_env() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py index edf9aed04f..802fcc9628 100644 --- 
a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py @@ -200,5 +200,10 @@ class TestDataParallelWithPyLayer(TestMultipleGpus): self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py') +class TestDataParallelInEagerMode(TestMultipleGpus): + def test_multiple_gpus_dynamic(self): + self.run_mnist_2gpu('parallel_dygraph_dataparallel_in_eager_mode.py') + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 47dc02705f..96f35eb9d2 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -42,6 +42,7 @@ from .. import compat as cpt from .lr import LRScheduler import copy from paddle import _C_ops +from paddle.fluid.framework import _in_eager_mode __all__ = [] @@ -1108,7 +1109,13 @@ class Optimizer(object): for p in param_group['params']: if not p.stop_gradient: param_list.append(p) - core.clear_gradients(param_list, set_to_zero) + + if _in_eager_mode(): + for p in param_list: + clear_func = p._zero_grads if set_to_zero else p.clear_gradient + clear_func() + else: + core.clear_gradients(param_list, set_to_zero) @imperative_base.no_grad def minimize(self, -- GitLab
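As a closing note on the control flow in reducer.cc above: each backward hook marks one tensor slot ready, and a group launches its fused allreduce only when its pending counter reaches zero and every earlier group has already launched. A toy, framework-free model of that bookkeeping follows; `ToyReducer` and its methods are assumed names for illustration, not Paddle API.

```python
# Toy model of EagerReducer's ready/pending bookkeeping (hypothetical names,
# no Paddle dependency): hooks mark variables ready, and groups are launched
# strictly in index order once all of their members are ready.
class ToyReducer:
    def __init__(self, group_indices):
        self.group_indices = group_indices
        # map variable index -> (group index, slot inside the group)
        self.locator = {
            var: (g, slot)
            for g, members in enumerate(group_indices)
            for slot, var in enumerate(members)
        }
        self.launched = []
        self.prepare_for_backward()

    def prepare_for_backward(self):
        # mirrors PrepareForBackward(): reset pending counters each iteration
        self.pending = [len(members) for members in self.group_indices]
        self.next_group = 0

    def mark_var_ready(self, var_index):
        group, _slot = self.locator[var_index]
        self.pending[group] -= 1
        if self.pending[group] == 0:
            self.mark_group_ready(group)

    def mark_group_ready(self, group):
        if group > self.next_group:
            return  # an earlier group is still pending; keep launch order
        while (self.next_group < len(self.group_indices)
               and self.pending[self.next_group] == 0):
            self.launched.append(self.next_group)  # fused allreduce goes here
            self.next_group += 1


if __name__ == "__main__":
    reducer = ToyReducer(group_indices=[[0, 1], [2], [3, 4]])
    for var in (2, 1, 0, 4, 3):    # hooks can arrive in any order
        reducer.mark_var_ready(var)
    print(reducer.launched)         # -> [0, 1, 2]
```

Launching groups strictly in index order keeps the sequence of collective calls identical on every rank even when gradient hooks fire in different orders, which is what prevents mismatched allreduce calls across processes.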