From e2d01eb650dba6267046c1cfd6e64cf8cfd74267 Mon Sep 17 00:00:00 2001 From: ShenLiang Date: Fri, 27 Nov 2020 19:01:21 +0800 Subject: [PATCH] Support dynamic graph distributed (#28997) * add reducer * refine envent for memorycopy * add concat&split for allreduce * apply concat & split for fuse tensor * fix nccl dep * fix the untest, compile problem and ddp initialize problem * fix untest for mac & add some comments & solve the repeated param in sublayers * fix untest for windows & fix document --- paddle/fluid/imperative/CMakeLists.txt | 6 +- paddle/fluid/imperative/all_reduce.cc | 3 + paddle/fluid/imperative/all_reduce.h | 3 + paddle/fluid/imperative/nccl_context.cc | 55 ++- paddle/fluid/imperative/nccl_context.h | 28 +- paddle/fluid/imperative/reducer.cc | 356 ++++++++++++++++++ paddle/fluid/imperative/reducer.h | 225 +++++++++++ paddle/fluid/platform/collective_helper.cc | 1 + paddle/fluid/platform/collective_helper.h | 1 + paddle/fluid/pybind/CMakeLists.txt | 1 + paddle/fluid/pybind/imperative.cc | 29 +- .../distributed/fleet/base/fleet_base.py | 14 +- python/paddle/fluid/dygraph/parallel.py | 126 ++++--- python/paddle/fluid/optimizer.py | 4 - .../fluid/tests/unittests/CMakeLists.txt | 4 + .../parallel_dygraph_sparse_embedding.py | 9 +- .../parallel_dygraph_sparse_embedding_fp64.py | 56 +++ .../fluid/tests/unittests/test_fleet_base.py | 23 +- .../tests/unittests/test_fleet_base_single.py | 1 - .../tests/unittests/test_imperative_group.py | 160 ++++++++ .../test_parallel_dygraph_sparse_embedding.py | 16 + python/paddle/hapi/model.py | 3 +- python/paddle/optimizer/adam.py | 4 - python/paddle/optimizer/adamw.py | 4 - python/paddle/optimizer/optimizer.py | 7 - 25 files changed, 1029 insertions(+), 110 deletions(-) create mode 100644 paddle/fluid/imperative/reducer.cc create mode 100644 paddle/fluid/imperative/reducer.h create mode 100644 python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py create mode 100644 python/paddle/fluid/tests/unittests/test_imperative_group.py diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 3d01e4fe46f..2da8169ebd9 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -2,7 +2,6 @@ cc_library(imperative_flag SRCS flags.cc DEPS gflags) cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform) cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry) -cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function) add_subdirectory(jit) cc_library(amp SRCS amp_auto_cast.cc DEPS layer ) cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp) @@ -12,9 +11,12 @@ cc_library(imperative_profiler SRCS profiler.cc) if(NOT WIN32) if(WITH_NCCL) cc_library(imperative_all_reduce SRCS all_reduce.cc DEPS collective_helper device_context selected_rows tensor) - cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context imperative_all_reduce) + cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context imperative_all_reduce var_type_traits) + cc_library(reducer SRCS reducer.cc DEPS layer imperative_all_reduce) endif() cc_library(data_loader SRCS data_loader.cc DEPS enforce) endif(NOT WIN32) +cc_library(gradient_accumulator SRCS 
gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function) + add_subdirectory(tests) diff --git a/paddle/fluid/imperative/all_reduce.cc b/paddle/fluid/imperative/all_reduce.cc index 0a601417de1..2c39ff6e86d 100644 --- a/paddle/fluid/imperative/all_reduce.cc +++ b/paddle/fluid/imperative/all_reduce.cc @@ -72,7 +72,9 @@ static void AllReduce(const framework::SelectedRows &src, const auto &src_rows = src.rows(); framework::Vector rows_num_vector(strategy.nranks_); rows_num_vector[strategy.local_rank_] = static_cast(src_rows.size()); + // CUDAMutableData use CalStream auto *gpu_rows_num_ptr = rows_num_vector.CUDAMutableData(place); + if (stream != dev_ctx->stream()) dev_ctx->Wait(); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( gpu_rows_num_ptr + strategy.local_rank_, gpu_rows_num_ptr, 1, ncclInt64, comm, stream)); @@ -106,6 +108,7 @@ static void AllReduce(const framework::SelectedRows &src, auto sizeof_dtype = framework::SizeOfType(dtype); int64_t row_offset = 0; + if (stream != dev_ctx->stream()) dev_ctx->Wait(); for (int i = 0; i < strategy.nranks_; ++i) { if (cpu_rows_num_ptr[i] > 0) { // 2. Broadcast the rows of SelectedRows diff --git a/paddle/fluid/imperative/all_reduce.h b/paddle/fluid/imperative/all_reduce.h index 249fb4e11f1..bd94e78f461 100644 --- a/paddle/fluid/imperative/all_reduce.h +++ b/paddle/fluid/imperative/all_reduce.h @@ -39,6 +39,9 @@ struct ParallelStrategy; void AllReduce(const framework::Variable &src, framework::Variable *dst, const ParallelStrategy &strategy); +void AllReduce(const framework::Variable &src, framework::Variable *dst, + const ParallelStrategy &strategy, cudaStream_t stream); + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc index 9c2c9925a34..e7c7b693707 100644 --- a/paddle/fluid/imperative/nccl_context.cc +++ b/paddle/fluid/imperative/nccl_context.cc @@ -14,8 +14,6 @@ #include "paddle/fluid/imperative/nccl_context.h" -#include "paddle/fluid/platform/collective_helper.h" - namespace paddle { namespace imperative { #if defined(PADDLE_WITH_NCCL) @@ -168,22 +166,51 @@ void NCCLParallelContext::BcastNCCLId(ncclUniqueId *nccl_id, int root) { } void NCCLParallelContext::Init() { - ncclUniqueId nccl_id; - if (strategy_.local_rank_ == 0) { - // generate the unique ncclid on the root worker - platform::dynload::ncclGetUniqueId(&nccl_id); - BcastNCCLId(&nccl_id, 0); + for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) { + ncclUniqueId nccl_id; + if (strategy_.local_rank_ == 0) { + // generate the unique ncclid on the root worker + platform::dynload::ncclGetUniqueId(&nccl_id); + BcastNCCLId(&nccl_id, 0); + } else { + BcastNCCLId(&nccl_id, 0); + } + int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device; + VLOG(0) << "init nccl context nranks: " << strategy_.nranks_ + << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id + << " ring id: " << ring_id; + + // it will assign nccl_comm in CUDADeviceContext within ring_id + platform::NCCLCommContext::Instance().CreateNCCLComm( + &nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, ring_id); + } +} + +void NCCLParallelContext::AllReduceByStream(const framework::Variable &src, + framework::Variable *dst, + int ring_id, bool use_calc_stream) { + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(place_), true, + platform::errors::Unimplemented( + "Dynamic graph mode does not support multi-CPU training yet.")); 
+ auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place_); + cudaStream_t stream = nullptr; + if (use_calc_stream) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place_); + stream = static_cast(dev_ctx)->stream(); } else { - BcastNCCLId(&nccl_id, 0); + stream = comm->stream(); } - int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device; - VLOG(0) << "init nccl context nranks: " << strategy_.nranks_ - << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id; + AllReduce(src, dst, strategy_, stream); +} - // it will assign nccl_comm in CUDADeviceContext within ring_id 0 - platform::NCCLCommContext::Instance().CreateNCCLComm( - &nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, 0); +paddle::platform::CUDADeviceContext *NCCLParallelContext::GetDeviceContext( + int ring_id) { + return platform::NCCLCommContext::Instance() + .Get(ring_id, place_) + ->dev_context(); } + #endif } // namespace imperative diff --git a/paddle/fluid/imperative/nccl_context.h b/paddle/fluid/imperative/nccl_context.h index cbd169f8da7..ebb1b17643f 100644 --- a/paddle/fluid/imperative/nccl_context.h +++ b/paddle/fluid/imperative/nccl_context.h @@ -23,15 +23,25 @@ #endif #include +#include #include +#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/device_context.h" + #if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/imperative/all_reduce.h" #include "paddle/fluid/platform/dynload/nccl.h" +#include "paddle/fluid/platform/nccl_helper.h" #endif + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/split.h" +#include "paddle/fluid/string/string_helper.h" namespace paddle { namespace imperative { @@ -41,6 +51,8 @@ struct ParallelStrategy { int local_rank_{0}; std::vector trainer_endpoints_{}; std::string current_endpoint_{""}; + // TODO(shenliang03): support multi stream communication + int nrings_{1}; }; class ParallelContext { @@ -53,13 +65,21 @@ class ParallelContext { virtual void Init() = 0; + virtual void AllReduceByStream(const framework::Variable& src, + framework::Variable* dst, int ring_id = 0, + bool use_calc_stream = false) = 0; +#if defined(PADDLE_WITH_NCCL) + virtual paddle::platform::CUDADeviceContext* GetDeviceContext( + int ring_id) = 0; +#endif + protected: ParallelStrategy strategy_; platform::Place place_; }; #if defined(PADDLE_WITH_NCCL) -class NCCLParallelContext : ParallelContext { +class NCCLParallelContext : public ParallelContext { public: explicit NCCLParallelContext(const ParallelStrategy& strategy, const platform::Place& place) @@ -71,6 +91,12 @@ class NCCLParallelContext : ParallelContext { void Init() override; + void AllReduceByStream(const framework::Variable& src, + framework::Variable* dst, int ring_id, + bool use_calc_stream) override; + + paddle::platform::CUDADeviceContext* GetDeviceContext(int ring_id) override; + protected: void RecvNCCLID(const std::string& endpoint, ncclUniqueId* nccl_id); diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc new file mode 100644 index 00000000000..71d68fa2e0d --- /dev/null +++ b/paddle/fluid/imperative/reducer.cc @@ -0,0 +1,356 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/imperative/reducer.h" + +namespace paddle { +namespace imperative { + +#if defined(PADDLE_WITH_NCCL) +std::shared_ptr Reducer::s_instance_ = NULL; + +Reducer::Reducer(const std::vector> &vars, + const std::vector> &group_indices, + const std::vector &is_sparse_gradient, + std::shared_ptr parallel_ctx) + : vars_(vars), + group_indices_(group_indices), + is_sparse_gradient_(is_sparse_gradient), + parallel_ctx_(parallel_ctx) { + VLOG(3) << "Start construct the Reducer ..."; + // initialize groups + InitializeGroups(group_indices); + + { + for (size_t group_index = 0; group_index < group_indices.size(); + ++group_index) { + for (size_t var_index = 0; var_index < group_indices[group_index].size(); + ++var_index) { + size_t global_var_index = group_indices[group_index][var_index]; + const auto variable_index = VariableIndex{ + .group_index = group_index, .inside_group_index = var_index, + }; + VLOG(3) << "add hook for var[" << vars_[global_var_index]->GradVarName() + << "], it's in group [" << group_index << "]"; + vars_[global_var_index]->SharedVar()->AddGradVarLeafBackwardHook( + std::unique_ptr( + new LambdaGradAccumulatorPostHook([=](VariableWrapper *grad) { + this->AddDistHook(grad, variable_index); + }))); + } + } + } + + compute_stream_ = static_cast( + platform::DeviceContextPool::Instance().Get(place_)) + ->stream(); + comm_stream_ = platform::NCCLCommContext::Instance().Get(0, place_)->stream(); + events_.resize(group_indices.size()); + for (auto &event : events_) { + event = platform::CudaEventResourcePool::Instance().New( + BOOST_GET_CONST(platform::CUDAPlace, place_).device); + } + comm_enent_ = platform::CudaEventResourcePool::Instance().New( + BOOST_GET_CONST(platform::CUDAPlace, place_).device); + + std::call_once(once_flag_, []() { + std::atexit([]() { Reducer::GetInstance()->ReleaseReducer(); }); + }); +} + +void Reducer::ReleaseReducer() { + for (auto &event : events_) { + event.reset(); + } + comm_enent_.reset(); +} + +int64_t Reducer::InitializeDenseGroups( + const std::vector &variable_indices_, Group *p_group) { + int64_t all_length = 0; + for (size_t index = 0; index < variable_indices_.size(); ++index) { + const auto variable_index = variable_indices_[index]; + const auto &var = vars_[variable_index]; + const auto var_name = var->Name(); + PADDLE_ENFORCE_EQ(is_sparse_gradient_[variable_index], false, + platform::errors::PreconditionNotMet( + "Tensor `%s`'s GRAD must be LoDTensor, but received " + "GRAD is SelectedRows", + var_name)); + + auto lod_tensor = var->MutableVar()->GetMutable(); + PADDLE_ENFORCE_EQ(lod_tensor->IsInitialized(), true, + platform::errors::PreconditionNotMet( + "Tensor `%s` is not initialized.", var_name)); + auto size = lod_tensor->numel(); + PADDLE_ENFORCE_GT( + size, 0, platform::errors::PreconditionNotMet( + "The number of tensor `%s`'s elements is 0.", var_name)); + all_length += size; + + p_group->length_.push_back(size); + // for concat operator + 
p_group->dense_tensors_.push_back(framework::Tensor()); + + // check the dtype and place, it must be same. + auto dtype = var->DataType(); + auto place = var->Place(); + if (index > 0) { + PADDLE_ENFORCE_EQ( + dtype, p_group->dtype_, + platform::errors::PreconditionNotMet( + "Tensor %s has different dtype. Expected dtype is %s, but actual " + "dtype is %s", + var_name, framework::DataTypeToString(p_group->dtype_), + framework::DataTypeToString(dtype))); + PADDLE_ENFORCE_EQ(place, place_, + platform::errors::PreconditionNotMet( + "Tensor %s has different place. Expected place is " + "%s, but actual place is %s", + var_name, place_, place)); + } else { + p_group->dtype_ = dtype; + place_ = place; + } + } + return all_length; +} + +// Each parameter will be initialized according to the group information. +// For the sparse parameter, sparse_contents_ in the group directly points +// to the parameter. For dense parameters, first construct an empty Tensor(). +// Then specify the actual memory in MarkVariableReady. +void Reducer::InitializeGroups( + const std::vector> &group_indices) { + VLOG(3) << "Start initialize groups .."; + // clear the group + groups_.clear(); + groups_.reserve(group_indices.size()); + + auto group_nums = group_indices.size(); + for (size_t group_index = 0; group_index < group_nums; ++group_index) { + const auto &variable_indices_ = group_indices[group_index]; + PADDLE_ENFORCE_GT( + variable_indices_.size(), 0, + platform::errors::PreconditionNotMet( + "The number of group_index[`%d`]'s elements is 0.", group_index)); + Group group; + group.variable_indices_ = variable_indices_; + int64_t all_length = 0; + + // It's just for check the sparse or dense + auto first_varbase = vars_[variable_indices_.front()]; + if (variable_indices_.size() == 1 && + is_sparse_gradient_[variable_indices_.front()]) { + // process the sparse gradient. one sparse, one group + group.sparse_contents_ = first_varbase->MutableGradVar(); + group.dtype_ = first_varbase->DataType(); + group.is_sparse_ = true; + } else { + // process the dense gradient. + all_length = InitializeDenseGroups(variable_indices_, &group); + // Alloc the continuous space + auto tensor = group.dense_contents_.GetMutable(); + tensor->Resize(framework::make_ddim({all_length})) + .mutable_data(place_, group.dtype_); + } + // Debug Message For Reducer + VLOG(3) << "the groups_[" << group_index << "] basic message:"; + VLOG(3) << "numul: " << all_length << " ;is_sparse: " << group.is_sparse_ + << " ;var number: " << group.variable_indices_.size(); + groups_.emplace_back(std::move(group)); + } +} + +// After each batch is calculated, the counter of each group(group.pending_) +// and allreudce sequence counter(next_group_) will be cleaned up again. +void Reducer::PrepareForBackward() { + VLOG(3) << "start reseting count.."; + next_group_ = 0; + std::for_each(groups_.begin(), groups_.end(), [](Group &group) { + group.pending_ = group.variable_indices_.size(); + }); +} + +// Add hook function to each leaf node. When the gradient of a leaf node is +// generated, if it is the sparse parameter, it will directly execute allreduce, +// if it is the dense parameter, it will execute three steps: 1, +// MarkVariableReady. Find the position of the corresponding group +// through var_index, share the gradient memory and the group dense_tensors, +// the group counter is reduced by 1. 
2, MarkGroupReady: When the group +// counter is 0, it means that allreduce can be emitted, and +// concat + allreduce + split is emitted in turn according to next_group_. +// 3, FinalizeBackward: after the end, synchronize each stream. +void Reducer::AddDistHook(VariableWrapper *var_warpper, + const VariableIndex &var_index) { + auto group_index = var_index.group_index; + auto &group = groups_[group_index]; + + if (!group.is_sparse_) { + // Only dense_contents_ need memory copy + MarkVariableReady(var_index, var_warpper); + } + if (--group.pending_ == 0) { + // can start allreduce + MarkGroupReady(group_index); + } + + if (next_group_ == groups_.size()) { + FinalizeBackward(); + } +} + +void Reducer::MarkVariableReady(const VariableIndex &var_index, + VariableWrapper *var_warpper) { + auto group_index = var_index.group_index; + auto variable_index = var_index.inside_group_index; + auto &group = groups_[group_index]; + auto length = group.length_[variable_index]; + + auto tensor = var_warpper->MutableVar()->GetMutable(); + group.dense_tensors_[variable_index].ShareDataWith(*tensor).Resize( + {static_cast(length)}); +} + +void Reducer::MarkGroupReady(size_t group_index) { + if (group_index > next_group_) { + LOG(WARNING) << "Maybe it need adjust the order of group"; + return; + } + + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaEventRecord(events_[group_index].get(), compute_stream_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamWaitEvent(comm_stream_, events_[group_index].get(), 0)); + + for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0; + ++next_group_) { + auto &group = groups_[next_group_]; + if (group.is_sparse_) { + VLOG(3) << "sparse group [" << next_group_ << "] start allreduce..."; + parallel_ctx_->AllReduceByStream(*group.sparse_contents_, + group.sparse_contents_, 0, false); + } else { + VLOG(3) << "dense group [" << next_group_ << "] start allreduce..."; + // Select common commstream to concat tensors + // group.dense_tensors ---> group.dense_contents_ + group.ConcatTensors(*parallel_ctx_->GetDeviceContext(0)); + + // Start allreduce + parallel_ctx_->AllReduceByStream(group.dense_contents_, + &(group.dense_contents_), 0, false); + // Select common commstream to split tensors + // group.dense_contents_ ---> group.dense_tensors + group.SplitTensors(*parallel_ctx_->GetDeviceContext(0)); + } + } +} + +void Reducer::FinalizeBackward() { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(comm_enent_.get(), comm_stream_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamWaitEvent(compute_stream_, comm_enent_.get(), 0)); + VLOG(3) << "In the batch, Reducer is finished..."; +} + +// According to the size of each parameter, it is allocated to different groups. +// The sparse parameter occupies a group exclusively. The dense parameters of +// the same data type are assigned to the same group. When dividing groups, the +// size of each group will be limited according to each value in +// group_size_limits in turn. When it is not enough, it will be divided +// by the last value of group_size_limits. The limit value is 0, which +// means that the parameter will monopolize the group. 
+std::vector> AssignGroupBySize( + const std::vector> &vars, + const std::vector &is_sparse_gradient, + const std::vector &group_size_limits) { + PADDLE_ENFORCE_EQ(vars.size(), is_sparse_gradient.size(), + platform::errors::PreconditionNotMet( + "vars len must be equal to is_sparse_gradient len, but " + "[%lu] != [%lu]", + vars.size(), is_sparse_gradient.size())); + // the return vector + std::vector> res; + + // Key: the var type + // Value: should use which index in group_size_limits for group size limit + std::unordered_map group_limit_index; + + // Key: the var type + // Value: + std::unordered_map, size_t>> + next_group; + + for (size_t i = 0; i < vars.size(); ++i) { + const auto &var = vars[i]; + if (is_sparse_gradient[i]) { + // we keep sparse var a single group + res.push_back({i}); + continue; + } + + const auto &var_dtype = var->DataType(); + const auto var_dtype_str = framework::DataTypeToString(var_dtype); + VLOG(3) << "var[" << var->GradVarName() << "] 's type is " + << var->DataType(); + auto &group_info = next_group[var_dtype_str]; + int64_t var_size = -1; + if (var->Var().IsType()) { + var_size = var->Var().Get().numel(); + } else { + VLOG(3) << "var " << var->Name() + << " is not tensor or selected_rows, so skip it"; + continue; + } + group_info.first.push_back(i); + group_info.second += framework::SizeOfType(var_dtype) * var_size; + + if (group_limit_index.find(var_dtype_str) == group_limit_index.end()) { + // means it is the first var of var_dtype + group_limit_index[var_dtype_str] = 0; + } + auto &cur_limit_index = group_limit_index[var_dtype_str]; + if (group_info.second >= group_size_limits[cur_limit_index]) { + // exceed group capacity and create a new group + res.emplace_back(std::move(group_info.first)); + group_info = std::pair, size_t>(); + cur_limit_index = + (std::min)(cur_limit_index + 1, group_size_limits.size() - 1); + } + } + + // add the final groups + for (auto &e : next_group) { + auto &group_info = e.second; + if (!group_info.first.empty()) { + res.emplace_back(std::move(group_info.first)); + } + } + + for (const auto &group_index : res) { + PADDLE_ENFORCE_NE( + group_index.empty(), true, + platform::errors::PreconditionNotMet( + "AssignGroupBySize construct empty group, please check.")); + } + std::sort(res.begin(), res.end(), + [](const std::vector &x, const std::vector &y) { + return x.front() < y.front(); + }); + return res; +} +#endif + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h new file mode 100644 index 00000000000..5e38f8abb18 --- /dev/null +++ b/paddle/fluid/imperative/reducer.h @@ -0,0 +1,225 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/imperative/layer.h" +#include "paddle/fluid/imperative/variable_wrapper.h" +#include "paddle/fluid/memory/memory.h" + +#if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/imperative/all_reduce.h" +#include "paddle/fluid/operators/math/concat_and_split.h" +#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/fluid/platform/cuda_resource_pool.h" +#endif + +namespace paddle { +namespace imperative { + +#if defined(PADDLE_WITH_NCCL) +template +void ConcatTensorsForAllReduce( + const platform::CUDADeviceContext& context, + const std::vector& dense_tensors_, + framework::Variable* p_dense_contents) { + operators::math::ConcatFunctor + concat_functor_; + concat_functor_(context, dense_tensors_, 0, + p_dense_contents->GetMutable()); +} + +template +void SplitTensorsForAllReduce(const platform::CUDADeviceContext& context, + framework::Variable* p_dense_contents, + std::vector* p_dense_tensors) { + auto* in = p_dense_contents->GetMutable(); + std::vector outs; + std::vector shape_refer; + + outs.reserve(p_dense_tensors->size()); + shape_refer.reserve(p_dense_tensors->size()); + + for (auto& tensor : *p_dense_tensors) { + outs.emplace_back(&tensor); + shape_refer.emplace_back(&tensor); + } + // Sometimes direct copies will be faster + if (p_dense_tensors->size() < 10) { + operators::StridedMemcpyWithAxis0(context, *in, shape_refer, &outs); + } else { + operators::math::SplitFunctor + split_functor_; + split_functor_(context, *in, shape_refer, 0, &outs); + } +} + +class Group { + public: + // Here, we use dense_contents_ & sparse_contents_ to + // achieve the tensor fuse. When is_sparse_ is true, sparse_contents_ work, + // conversely, dense_contents_ works. It is mutex relationship. + framework::Variable dense_contents_; + framework::Variable* sparse_contents_ = nullptr; + bool is_sparse_ = false; + + // for concat kernel + std::vector dense_tensors_; + + std::vector length_; + // Global indices of participating variables in the group + std::vector variable_indices_; + + // Number of params that haven't been ready. When it is 0, it means + // the group is ready. 
+ size_t pending_ = -1; + + // external message of group + framework::proto::VarType::Type dtype_; + + // context is used to select the stream for concat + void ConcatTensors(const platform::CUDADeviceContext& context) { + switch (dtype_) { + case framework::proto::VarType::FP16: + ConcatTensorsForAllReduce(context, dense_tensors_, + &dense_contents_); + break; + case framework::proto::VarType::FP32: + ConcatTensorsForAllReduce(context, dense_tensors_, + &dense_contents_); + break; + case framework::proto::VarType::FP64: + ConcatTensorsForAllReduce(context, dense_tensors_, + &dense_contents_); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when it concats tensors for " + "allreduce.", + framework::DataTypeToString(dtype_))); + } + } + + // context is used to select the stream for split + void SplitTensors(const platform::CUDADeviceContext& context) { + switch (dtype_) { + case framework::proto::VarType::FP16: + SplitTensorsForAllReduce(context, &dense_contents_, + &dense_tensors_); + break; + case framework::proto::VarType::FP32: + SplitTensorsForAllReduce(context, &dense_contents_, + &dense_tensors_); + break; + case framework::proto::VarType::FP64: + SplitTensorsForAllReduce(context, &dense_contents_, + &dense_tensors_); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when it splits tensors for " + "allreduce.", + framework::DataTypeToString(dtype_))); + } + } +}; + +struct VariableIndex { + // record the index in groups_ + size_t group_index; + size_t inside_group_index; +}; + +class Reducer { + public: + explicit Reducer( + const std::vector>& vars, + const std::vector>& group_indices, + const std::vector& is_sparse_gradient, + std::shared_ptr parallel_ctx); + + virtual ~Reducer() {} + + void InitializeGroups(const std::vector>& group_indices); + + int64_t InitializeDenseGroups(const std::vector& variable_indices_, + Group* p_group); + + void PrepareForBackward(); + + void AddDistHook(VariableWrapper* var_warpper, + const VariableIndex& var_index); + + void MarkVariableReady(const VariableIndex& var_index, + VariableWrapper* var_warpper); + + void MarkGroupReady(size_t group_index); + + void FinalizeBackward(); + + void ReleaseReducer(); + + // Reducer Singleton + static std::shared_ptr SetInstance( + const std::vector>& vars, + const std::vector>& group_indices, + const std::vector& is_sparse_gradient, + std::shared_ptr parallel_ctx) { + if (NULL == s_instance_) { + s_instance_.reset(new paddle::imperative::Reducer( + vars, group_indices, is_sparse_gradient, parallel_ctx)); + } + return s_instance_; + } + + static std::shared_ptr GetInstance() { + PADDLE_ENFORCE_EQ( + s_instance_ != NULL, true, + platform::errors::InvalidArgument("Reducer is not initialized.")); + return s_instance_; + } + + private: + std::vector> vars_; + std::vector> group_indices_; + static std::shared_ptr s_instance_; + std::vector groups_; + size_t next_group_ = 0; + platform::Place place_; + std::once_flag once_flag_; + std::vector is_sparse_gradient_; + std::shared_ptr parallel_ctx_; + + std::vector> events_; + std::shared_ptr comm_enent_; + cudaStream_t compute_stream_; + cudaStream_t comm_stream_; +}; + +std::vector> AssignGroupBySize( + const std::vector>& tensors, + const std::vector& is_sparse_gradient, + const std::vector& group_size_limits); +#endif + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/platform/collective_helper.cc 
b/paddle/fluid/platform/collective_helper.cc index 54dac976276..d2d9b41fcce 100644 --- a/paddle/fluid/platform/collective_helper.cc +++ b/paddle/fluid/platform/collective_helper.cc @@ -42,6 +42,7 @@ class NCCLCommImpl : public NCCLComm { void set_dev_ctx(std::unique_ptr&& dev_ctx) { dev_ctx_ = std::move(dev_ctx); } + CUDADeviceContext* dev_context() const override { return dev_ctx_.get(); } private: int ring_id_; diff --git a/paddle/fluid/platform/collective_helper.h b/paddle/fluid/platform/collective_helper.h index cc19fd5ac49..d44199f309b 100644 --- a/paddle/fluid/platform/collective_helper.h +++ b/paddle/fluid/platform/collective_helper.h @@ -55,6 +55,7 @@ class NCCLComm { virtual int device_id() const = 0; virtual ncclComm_t comm() const = 0; virtual cudaStream_t stream() const = 0; + virtual CUDADeviceContext* dev_context() const = 0; virtual ~NCCLComm() = default; }; diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 6fd1b7e1d36..c25b692a4a0 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -5,6 +5,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp if (WITH_NCCL) set(PYBIND_DEPS ${PYBIND_DEPS} nccl_wrapper) + set(PYBIND_DEPS ${PYBIND_DEPS} reducer) endif() if(NOT WIN32) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 7e3e175c09e..303dcc0e0ab 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -36,6 +36,7 @@ limitations under the License. */ #include "paddle/fluid/imperative/nccl_context.h" #include "paddle/fluid/imperative/partial_grad_engine.h" #include "paddle/fluid/imperative/profiler.h" +#include "paddle/fluid/imperative/reducer.h" #include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/memory/allocation/mmap_allocator.h" @@ -1232,13 +1233,33 @@ void BindImperative(py::module *m_ptr) { py::call_guard()); #if defined(PADDLE_WITH_NCCL) - py::class_ nccl_ctx(m, - "NCCLParallelContext"); - - nccl_ctx + py::class_>(m, + "ParallelContext"); + py::class_>( + m, "NCCLParallelContext") .def(py::init()) .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); }); + + py::class_>( + m, "Reducer", R"DOC()DOC") + .def(py::init( + [](const std::vector> &vars, + const std::vector> &group_indices, + const std::vector &is_sparse_gradient, + std::shared_ptr parallel_ctx) { + return imperative::Reducer::SetInstance( + vars, group_indices, is_sparse_gradient, parallel_ctx); + })) + .def("prepare_for_backward", &imperative::Reducer::PrepareForBackward, + py::call_guard()); + + m.def("assign_group_by_size", &imperative::AssignGroupBySize, py::arg("vars"), + py::arg("is_sparse_gradient"), + py::arg("group_size_limits") = std::vector{25 * 1024 * 1024}, + py::call_guard()); #endif } diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 3d26841876b..4db7f70e3cf 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -587,12 +587,19 @@ class Fleet(object): return self @dygraph_only - def distributed_model(self, model): + def distributed_model(self, model, group_size_limits=25, + small_group_size=1): """ Return distributed data parallel model (Only work in dygraph mode) Args: model (Layer): the user-defind model which inherits Layer. 
+ group_size_limits(int, optional): It is up limited memory size(MB) of one group + parameters' gradient which is the input of communication + calling(e.g NCCLAllReduce). Default: 25. + small_group_size(int, optional): It is up limited memory size(MB) of last group in communication + calling. Making the last group small is useful to + improve performance. Default: 1. Returns: distributed data parallel model which inherits Layer. @@ -646,7 +653,10 @@ class Fleet(object): """ assert model is not None - self.model = paddle.DataParallel(model) + self.model = paddle.DataParallel( + model, + group_size_limits=group_size_limits, + small_group_size=small_group_size) return self.model @dygraph_only diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 83b6cf34134..46fdf05d0dd 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -24,6 +24,8 @@ from paddle.fluid.dygraph import layers from paddle.fluid.dygraph import parallel_helper from paddle.fluid.dygraph import to_variable, no_grad from paddle.utils import deprecated +from paddle.fluid.dygraph import nn +import warnings __all__ = ["prepare_context", "ParallelEnv", "DataParallel"] @@ -284,58 +286,6 @@ def scale_loss(loss): return scaled_loss -@no_grad -def apply_collective_grads(parameters): - if not ParallelEnv().world_size > 1: - return - - grad_var_set = set() - grad_vars = [] - sparse_grad_vars = [] - strategy = _build_default_parallel_strategy() - for param in parameters: - # NOTE(zcd): The grad_ivar maybe no generated. - if param.trainable and (param._grad_ivar() is not None): - g_var = param._grad_ivar() - if g_var._is_sparse(): - sparse_grad_vars.append(g_var) - continue - grad_vars.append(g_var) - assert g_var not in grad_var_set - grad_var_set.add(g_var) - - if sparse_grad_vars: - sparse_grad_vars.sort(key=lambda x: x.name) - for grad_var in sparse_grad_vars: - grad_var._allreduce(strategy) - - # FIXME(zcd): the type of the var should be LoDTensor, i.e - # the gradients should be dense, otherwise, the following - # logic should be updated. - # 128 MB as a group - mega_bytes = 128 * 1024 * 1024 - group_idx = 0 - memory_counter = 0 - grad_var_groups = OrderedDict() - dtype = grad_vars[0].dtype - for g_var in grad_vars: - # NOTE: the dtype of the same group should be the same. - bytes = np.prod(g_var.shape) * core.size_of_dtype(g_var.dtype) - if memory_counter < mega_bytes and dtype == g_var.dtype: - memory_counter += bytes - else: - memory_counter = bytes - group_idx += 1 - grad_var_groups.setdefault(group_idx, []).append(g_var) - - coalesced_grads_and_vars = _coalesce_tensors(grad_var_groups) - - for coalesced_grad, _, _ in coalesced_grads_and_vars: - coalesced_grad._allreduce(strategy) - - _split_tensors(coalesced_grads_and_vars) - - class DataParallel(layers.Layer): """ Run the dygraph module with data parallelism. @@ -359,6 +309,12 @@ class DataParallel(layers.Layer): layers(Layer): The module that should be executed by data parallel. strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism, contains environment configuration related to parallel execution. Default: None. + group_size_limits(int, optional): It is up limited memory size(MB) of one group + parameters' gradient which is the input of communication + calling(e.g NCCLAllReduce). Default: 25. + small_group_size(int, optional): It is up limited memory size(MB) of last group in communication + calling. 
Making the last group small is useful to + improve performance. Default: 1. Returns: Layer: The data paralleled module. @@ -410,7 +366,11 @@ class DataParallel(layers.Layer): # train() """ - def __init__(self, layers, strategy=None): + def __init__(self, + layers, + strategy=None, + group_size_limits=25, + small_group_size=1): super(DataParallel, self).__init__(layers.full_name() + "_data_parallel") @@ -425,7 +385,67 @@ class DataParallel(layers.Layer): else: self._strategy = _build_default_parallel_strategy() + if self._strategy.nranks > 1: + self.group_size_limits = int(group_size_limits * 1024 * 1024) + # NOTE(shenliang03): We can set environment variables to control + # the size of the group, Default: 1MB. The role of this small group is: + # when the last group allreduce, the overlap cannot work. Making the + # the last group small is useful to improve performance. + self.small_group_size = int(small_group_size * 1024 * 1024) + self.init_reducer() + else: + warnings.warn( + "nranks is less than 2, " + "maybe you need to check the current system environment." + " Need to use spawn or fleetrun to " + "start distributed programs.") + + def init_reducer(self): + layers_param = [] + params_set = set() + for sublayer in self.sublayers(): + for _, param in sublayer.named_parameters(include_sublayers=False): + if param is None or param in params_set: + continue + params_set.add(param) + if not isinstance(param, core.VarBase): + raise TypeError("The data type of '%s' must be Varbase" % + param.name) + if param.trainable: + layers_param.append((sublayer, param)) + + trainable_parameters = [param for _, param in layers_param] + + # NOTE(shenliang03): Here we can only use the attributes to judge whether + # parameter is sparse(or SelectedRows). The reason is that the sparse message + # can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter, + # we should add the layer here like "nn.Embedding". + def check_layer_sparse(sublayer): + if isinstance(sublayer, nn.Embedding): + return sublayer._is_sparse + return False + + is_sparse_gradient = [ + check_layer_sparse(sublayer) for sublayer, _ in layers_param + ] + + self.group_indices = core.assign_group_by_size( + trainable_parameters, is_sparse_gradient, + [self.small_group_size, self.group_size_limits]) + + assert parallel_helper.__parallel_ctx__clz__ is not None, \ + "ParallelContext must be initialized before. You should use init_parallel_env() before" \ + "constructing the DataParallel." + + self._reducer = core.Reducer(trainable_parameters, + list(reversed(self.group_indices)), + is_sparse_gradient, + parallel_helper.__parallel_ctx__clz__) + def forward(self, *inputs, **kwargs): + if self._strategy.nranks > 1: + self._reducer.prepare_for_backward() + return self._layers(*inputs, **kwargs) @deprecated( diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 2d95bfa8c54..f3c4984e29e 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -22,7 +22,6 @@ from collections import defaultdict import paddle from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard -from paddle.fluid.dygraph.parallel import apply_collective_grads from . import framework from . 
import layers @@ -772,9 +771,6 @@ class Optimizer(object): parameter_list = parameter_list if parameter_list \ else self._parameter_list - if paddle.distributed.get_world_size() > 1: - apply_collective_grads(parameter_list) - params_grads = [] for param in parameter_list: if not param.trainable: diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 2bb3b45bc41..1ddafa97a50 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -151,6 +151,10 @@ if (WITH_NCCL) endif() endif() +if(NOT WITH_NCCL) + list(REMOVE_ITEM TEST_OPS test_imperative_group) +endif() + if(NOT WITH_GPU OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_boxps) endif() diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding.py index e0b833df0c0..226f1293ef6 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding.py @@ -30,7 +30,8 @@ class SimpleNet(fluid.Layer): vocab_size, num_steps=20, init_scale=0.1, - is_sparse=False): + is_sparse=False, + dtype="float32"): super(SimpleNet, self).__init__() self.hidden_size = hidden_size self.vocab_size = vocab_size @@ -38,7 +39,7 @@ class SimpleNet(fluid.Layer): self.num_steps = num_steps self.embedding = Embedding( size=[self.vocab_size, self.hidden_size], - dtype='float32', + dtype=dtype, is_sparse=is_sparse, param_attr=fluid.ParamAttr( name='embedding_param', @@ -47,13 +48,13 @@ class SimpleNet(fluid.Layer): self.softmax_weight = self.create_parameter( attr=fluid.ParamAttr(), shape=[self.hidden_size, self.vocab_size], - dtype="float32", + dtype=dtype, default_initializer=fluid.initializer.UniformInitializer( low=-self.init_scale, high=self.init_scale)) self.softmax_bias = self.create_parameter( attr=fluid.ParamAttr(), shape=[self.vocab_size], - dtype="float32", + dtype=dtype, default_initializer=fluid.initializer.UniformInitializer( low=-self.init_scale, high=self.init_scale)) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py new file mode 100644 index 00000000000..e7b4e605253 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py @@ -0,0 +1,56 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import numpy as np + +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph.nn import Embedding +from paddle.fluid.dygraph.base import to_variable + +from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase +from parallel_dygraph_sparse_embedding import SimpleNet, fake_sample_reader, TestSparseEmbedding + +# global configs +batch_size = 4 +batch_num = 200 +hidden_size = 10 +vocab_size = 1000 +num_steps = 3 +init_scale = 0.1 + + +class TestSparseEmbeddingFP64(TestSparseEmbedding): + def get_model(self): + model = SimpleNet( + hidden_size=hidden_size, + vocab_size=vocab_size, + num_steps=num_steps, + init_scale=init_scale, + is_sparse=True, + dtype="float64") + + train_reader = paddle.batch( + fake_sample_reader(), batch_size=batch_size, drop_last=True) + + optimizer = fluid.optimizer.SGD(learning_rate=0.001, + parameter_list=model.parameters()) + + return model, train_reader, optimizer + + +if __name__ == "__main__": + runtime_main(TestSparseEmbeddingFP64) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base.py b/python/paddle/fluid/tests/unittests/test_fleet_base.py index f50d80d215d..99986043ec7 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_base.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_base.py @@ -160,15 +160,20 @@ class TestFleetDygraph(unittest.TestCase): learning_rate=0.01, parameters=layer.parameters()) # remove init cause this UT cannot launch distributed task adam = fleet.distributed_optimizer(adam) - dp_layer = fleet.distributed_model(layer) - lr = 0.001 - adam.set_lr(lr) - cur_lr = adam.get_lr() - assert (lr == cur_lr) - state_dict = adam.state_dict() - adam.set_state_dict(state_dict) - - final_strategy = fleet._final_strategy() + try: + dp_layer = fleet.distributed_model(layer) + except Exception as e: + # This is just for testing the interface, + # and will not actually be called. Therefore, + # use "try-except" to avoid errors. + lr = 0.001 + adam.set_lr(lr) + cur_lr = adam.get_lr() + assert (lr == cur_lr) + state_dict = adam.state_dict() + adam.set_state_dict(state_dict) + + final_strategy = fleet._final_strategy() class TestFleetBaseSingleError(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base_single.py b/python/paddle/fluid/tests/unittests/test_fleet_base_single.py index 111a6331958..03e29399482 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_base_single.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_base_single.py @@ -62,7 +62,6 @@ class TestFleetDygraphSingle(unittest.TestCase): loss = loss_fn(outputs, labels) loss = dp_layer.scale_loss(loss) loss.backward() - dp_layer.apply_collective_grads() adam.step() adam.clear_grad() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_group.py b/python/paddle/fluid/tests/unittests/test_imperative_group.py new file mode 100644 index 00000000000..299efa6d9c1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_group.py @@ -0,0 +1,160 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import contextlib +import unittest +import numpy as np +import six +import unittest + +import paddle +import paddle.fluid as fluid +import paddle.fluid.dygraph as dygraph +from paddle.fluid.dygraph.nn import Linear +import paddle.fluid.core as core +from paddle.fluid.optimizer import SGDOptimizer + + +class MLP(fluid.Layer): + def __init__(self, param_attr=None, bias_attr=None): + super(MLP, self).__init__() + + self._linear1 = Linear(784, 10) + self._linear2 = Linear(10, 10) + + def forward(self, inputs): + y = self._linear1(inputs) + y = self._linear2(y) + return y + + +class TestDataParallelGroup(unittest.TestCase): + def create_varbase(self, dtype, shape, + type=core.VarDesc.VarType.LOD_TENSOR): + return core.VarBase(dtype, shape, "", type, True) + + def test_construct_group0(self): + # one dtype & one limit capability + var_list = [] + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 50])) + var_list.append( + self.create_varbase(core.VarDesc.VarType.FP32, [2, 100])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25])) + res = core.assign_group_by_size(var_list, [False, False, False, False], + [400]) + self.assertEqual([[0], [1], [2], [3]], res) + + def test_construct_group1(self): + # multi dtype & one limit capability + var_list = [] + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + res = core.assign_group_by_size( + var_list, [False, False, False, False, False, False], [400]) + self.assertEqual([[0, 2], [1, 3], [4], [5]], res) + + def test_construct_group2(self): + # one dtype & multi limit capability + var_list = [] + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 50])) + res = core.assign_group_by_size(var_list, [False, False, False, False], + [400, 800]) + self.assertEqual([[0], [1, 2], [3]], res) + + def test_construct_group3(self): + # multi dtype & multi limit capability + var_list = [] + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + res = core.assign_group_by_size( + var_list, 
[False, False, False, False, False, False], [200, 400]) + self.assertEqual([[0], [1], [2, 4], [3, 5]], res) + + def test_construct_group4(self): + # multi dtype & zero limit capability + var_list = [] + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + res = core.assign_group_by_size( + var_list, [False, False, False, False, False, False], [0]) + self.assertEqual([[0], [1], [2], [3], [4], [5]], res) + + def test_construct_group5(self): + # multi dtype & infinite capability + var_list = [] + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + res = core.assign_group_by_size( + var_list, [False, False, False, False, False, False], [10000]) + self.assertEqual([[0, 2, 4], [1, 3, 5]], res) + + def test_construct_group6(self): + # multi dtype & limit capability & multi tensor type + var_list = [] + var_list.append( + self.create_varbase(core.VarDesc.VarType.FP32, [1, 50], + core.VarDesc.VarType.SELECTED_ROWS)) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append( + self.create_varbase(core.VarDesc.VarType.FP64, [1, 25], + core.VarDesc.VarType.SELECTED_ROWS)) + res = core.assign_group_by_size( + var_list, [True, False, False, False, False, True], [400]) + self.assertEqual([[0], [1, 3], [2, 4], [5]], res) + + def test_construct_group7(self): + # multi dtype & multi limit capability & multi tensor type + var_list = [] + var_list.append( + self.create_varbase(core.VarDesc.VarType.FP32, [1, 50], + core.VarDesc.VarType.SELECTED_ROWS)) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP64, [1, 25])) + var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [1, 50])) + var_list.append( + self.create_varbase(core.VarDesc.VarType.FP64, [1, 25], + core.VarDesc.VarType.SELECTED_ROWS)) + res = core.assign_group_by_size( + var_list, [True, False, False, False, False, True], [200, 400]) + self.assertEqual([[0], [1], [2], [3], [4], [5]], res) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py index 7f051f1005c..43907da6098 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py @@ -22,6 +22,7 @@ import paddle.fluid as 
fluid from test_dist_base import TestDistBase from spawn_runner_base import TestDistSpawnRunner from parallel_dygraph_sparse_embedding import TestSparseEmbedding +from parallel_dygraph_sparse_embedding_fp64 import TestSparseEmbeddingFP64 flag_name = os.path.splitext(__file__)[0] @@ -41,6 +42,21 @@ class TestParallelDygraphSparseEmdedding(TestDistBase): log_name=flag_name) +class TestParallelDygraphSparseEmdeddingFP64(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = True + + def test_sparse_embedding_fp64(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_sparse_embedding_fp64.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + class TestParallelDygraphSparseEmdeddingSpawn(TestDistSpawnRunner): def test_sparse_embedding_with_spawn(self): if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index a81a4d7faa7..7c731c40029 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -49,6 +49,7 @@ from paddle.fluid.executor import scope_guard, Executor from paddle.fluid.dygraph.layers import Layer from paddle.metric import Metric from paddle.static import InputSpec as Input +import paddle.distributed as dist from .callbacks import config_callbacks, EarlyStopping from .model_summary import summary @@ -886,6 +887,7 @@ class Model(object): # init backend if fluid.in_dygraph_mode(): + dist.init_parallel_env() self._adapter = DynamicGraphAdapter(self) else: self._adapter = StaticGraphAdapter(self) @@ -1270,7 +1272,6 @@ class Model(object): fluid.default_main_program().random_seed = main_prog_seed fluid.default_startup_program( ).random_seed = startup_prog_seed - fluid.dygraph.parallel.prepare_context() else: prepare_distributed_context(self._place) _parallel_context_initialized = True diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 37510231219..910c9b185db 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -18,7 +18,6 @@ from ..fluid import framework from ..fluid.framework import Variable import paddle -from paddle.fluid.dygraph.parallel import apply_collective_grads __all__ = ["Adam"] @@ -271,9 +270,6 @@ class Adam(Optimizer): adam.step() adam.clear_grad() """ - if paddle.distributed.get_world_size() > 1: - apply_collective_grads(self._parameter_list) - self._dtype = None params_grads = [] for param in self._parameter_list: diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index b597109d314..2aa7fa115ec 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -17,7 +17,6 @@ from .adam import Adam from ..fluid import framework from ..fluid.dygraph import base as imperative_base import paddle -from paddle.fluid.dygraph.parallel import apply_collective_grads __all__ = ['AdamW'] @@ -211,9 +210,6 @@ class AdamW(Adam): @framework.dygraph_only @imperative_base.no_grad def step(self): - if paddle.distributed.get_world_size() > 1: - apply_collective_grads(self._parameter_list) - self._dtype = None params_grads = [] for param in self._parameter_list: diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 030d419de48..295821a93cd 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -22,7 +22,6 @@ from collections import defaultdict import paddle from 
paddle.fluid.distribute_lookup_table import find_distributed_lookup_table from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard -from paddle.fluid.dygraph.parallel import apply_collective_grads from ..fluid import framework from ..fluid import layers @@ -681,9 +680,6 @@ class Optimizer(object): parameter_list = parameters if parameters \ else self._parameter_list - if paddle.distributed.get_world_size() > 1: - apply_collective_grads(parameter_list) - params_grads = [] for param in parameter_list: if not param.trainable: @@ -912,9 +908,6 @@ class Optimizer(object): adam.step() adam.clear_grad() """ - if paddle.distributed.get_world_size() > 1: - apply_collective_grads(self._parameter_list) - self._dtype = None params_grads = [] for param in self._parameter_list: -- GitLab