Unverified · Commit 57d5ffa5 · Authored by Haohongxiang, committed by GitHub

[Dygraph] Fix memory bugs of no sync and SplitTensors in DataParallel (#47369)

* fix no sync bugs

* update

* update task chain

fix: update wait chain

feat: add `GetDeviceContext` for gloo

* fix oom

* fix dev

* update

* update
Co-authored-by: LiYuRio <liyuruijx@163.com>
Co-authored-by: ForFishes <2282912238@qq.com>
Parent 6baeb2d1
@@ -41,6 +41,8 @@ bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) {
 
 void ProcessGroup::Task::Synchronize() {}
 
+void ProcessGroup::Task::UpdateWaitChain(const phi::DeviceContext& ctx) {}
+
 ProcessGroup::ProcessGroup(int rank,
                            int size,
                            const platform::Place& place,
...
@@ -66,6 +66,7 @@ class ProcessGroup {
     virtual bool IsCompleted();
     virtual bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
     virtual void Synchronize();
+    virtual void UpdateWaitChain(const phi::DeviceContext& ctx);
     bool IsSync() const { return sync_op_; }
 
    protected:
@@ -92,7 +93,7 @@ class ProcessGroup {
   int GetSize() const { return size_; }
   virtual const std::string GetBackendName() const = 0;
 
-  virtual phi::DeviceContext* GetDeviceContext(const Place& place) const {
+  virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "Does not support to get device_context from ProcessGroup%s.",
         GetBackendName()));
...
@@ -150,6 +150,11 @@ class ProcessGroupGloo : public ProcessGroup {
     return GLOO_BACKEND_NAME;
   }
 
+  const phi::DeviceContext& GetDeviceContext(
+      const Place& place) const override {
+    return *platform::DeviceContextPool::Instance().Get(place);
+  }
+
   // Helper functions for Gloo.
   static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
       const std::string& hostname);
...
@@ -110,6 +110,11 @@ bool ProcessGroupNCCL::NCCLTask::IsCompleted() {
   return true;
 }
 
+void ProcessGroupNCCL::NCCLTask::UpdateWaitChain(
+    const phi::DeviceContext& ctx) {
+  control_events_[0].Record(*static_cast<const phi::GPUContext*>(&ctx));
+}
+
 void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
                                        std::vector<int64_t> tensor_shape) {
   int64_t len_size = (*split_sizes).size();
@@ -1591,15 +1596,15 @@ ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
   return iter->second[0]->GetNcclComm();
 }
 
-phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
+const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
     const Place& place) const {
   return GetDeviceContext(place, /*use_calc_stream*/ false);
 }
 
-phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
+const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   if (use_calc_stream) {
-    return platform::DeviceContextPool::Instance().Get(place);
+    return *platform::DeviceContextPool::Instance().Get(place);
   } else {
     std::vector<Place> places = {place};
     const auto& iter = places_to_ctx_.find(GetKeyFromPlaces(places));
@@ -1607,7 +1612,7 @@ phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
                       places_to_ctx_.end(),
                       platform::errors::InvalidArgument(
                           "Cannot find device context in process group."));
-    return iter->second[0].get();
+    return *iter->second[0];
   }
 }
...
@@ -75,6 +75,8 @@ class ProcessGroupNCCL : public ProcessGroupStream {
 
     virtual ~NCCLTask();
 
+    void UpdateWaitChain(const phi::DeviceContext& ctx) override;
+
     std::vector<EventManager> control_events_;
     std::vector<phi::DenseTensor> barrierTensors_;
@@ -96,10 +98,10 @@ class ProcessGroupNCCL : public ProcessGroupStream {
     return std::string(NCCL_BACKEND_NAME);
   }
 
-  phi::DeviceContext* GetDeviceContext(const Place& place) const override;
+  const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
 
-  phi::DeviceContext* GetDeviceContext(const Place& place,
-                                       bool use_calc_stream) const override;
+  const phi::DeviceContext& GetDeviceContext(
+      const Place& place, bool use_calc_stream) const override;
 
   std::shared_ptr<ProcessGroup::Task> AllReduce(
       std::vector<phi::DenseTensor>& in_tensors,  // NOLINT
...
@@ -23,7 +23,7 @@ ProcessGroupStream::ProcessGroupStream(int rank,
                                        int gid)
     : ProcessGroup(rank, size, place, gid) {}
 
-phi::DeviceContext* ProcessGroupStream::GetDeviceContext(
+const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   PADDLE_THROW(platform::errors::InvalidArgument(
       "ProcessGroup%s does not support get device_context.", GetBackendName()));
...
@@ -54,8 +54,8 @@ class ProcessGroupStream : public ProcessGroup {
   ProcessGroupStream(int rank, int size, const platform::Place& place, int gid);
   virtual ~ProcessGroupStream() = default;
 
-  virtual phi::DeviceContext* GetDeviceContext(const Place& place,
-                                               bool use_calc_stream) const;
+  virtual const phi::DeviceContext& GetDeviceContext(
+      const Place& place, bool use_calc_stream) const;
 
   std::shared_ptr<ProcessGroup::Task> AllGather(
       std::vector<phi::DenseTensor>& in_tensors,  // NOLINT
...
@@ -25,18 +25,18 @@ namespace distributed {
 
 template <typename DeviceContext, typename T>
 struct ConcatDenseTensor {
-  void operator()(const DeviceContext *context,
+  void operator()(const DeviceContext &context,
                   const std::vector<phi::DenseTensor> &in,
                   phi::DenseTensor *out,
                   int axis = 0) {
     phi::funcs::ConcatFunctor<DeviceContext, T> concat_functor;
-    concat_functor(*context, in, axis, out);
+    concat_functor(context, in, axis, out);
   }
 };
 
 template <typename DeviceContext, typename T>
 struct SplitDenseTensor {
-  void operator()(const DeviceContext *context,
+  void operator()(const DeviceContext &context,
                   const phi::DenseTensor &in,
                   std::vector<phi::DenseTensor *> *out,
                   int axis = 0) {
@@ -46,19 +46,19 @@ struct SplitDenseTensor {
       shape_refer.emplace_back(p_tensor);
     }
     phi::funcs::SplitFunctor<DeviceContext, T> split_functor;
-    split_functor(*context, in, shape_refer, axis, out);
+    split_functor(context, in, shape_refer, axis, out);
   }
 };
 
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
 template <typename T>
 struct ConcatDenseTensor<platform::CustomDeviceContext, T> {
-  void operator()(const platform::CustomDeviceContext *context,
+  void operator()(const platform::CustomDeviceContext &context,
                   const std::vector<phi::DenseTensor> &in,
                   phi::DenseTensor *out,
                   int axis = 0) {
     auto *out_data = out->data<T>();
-    auto *device = phi::DeviceManager::GetDeviceWithPlace(context->GetPlace());
+    auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
     size_t offset = 0;
     for (const auto &tensor : in) {
       const auto *in_data = tensor.data<T>();
@@ -71,12 +71,12 @@ struct ConcatDenseTensor<platform::CustomDeviceContext, T> {
 
 template <typename T>
 struct SplitDenseTensor<platform::CustomDeviceContext, T> {
-  void operator()(const platform::CustomDeviceContext *context,
+  void operator()(const platform::CustomDeviceContext &context,
                   const phi::DenseTensor &in,
                   std::vector<phi::DenseTensor *> *out,
                   int axis = 0) {
     auto *in_data = in.data<T>();
-    auto *device = phi::DeviceManager::GetDeviceWithPlace(context->GetPlace());
+    auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
     size_t offset = 0;
     for (auto *p_tensor : *out) {
       auto *out_data = p_tensor->data<T>();
@@ -89,7 +89,7 @@ struct SplitDenseTensor<platform::CustomDeviceContext, T> {
 #endif
 
 template <typename DeviceContext>
-void ConcatDenseTensorWithType(const DeviceContext *dev_ctx,
+void ConcatDenseTensorWithType(const DeviceContext &dev_ctx,
                                const std::vector<phi::DenseTensor> &t_list,
                                phi::DenseTensor *p_out,
                                phi::DataType type) {
@@ -126,7 +126,7 @@ void ConcatDenseTensorWithType(const DeviceContext *dev_ctx,
 }
 
 template <typename DeviceContext>
-void SplitDenseTensorWithType(const DeviceContext *dev_ctx,
+void SplitDenseTensorWithType(const DeviceContext &dev_ctx,
                               const phi::DenseTensor &t_in,
                               std::vector<phi::DenseTensor *> *p_list,
                               phi::DataType type) {
@@ -162,16 +162,16 @@ void SplitDenseTensorWithType(const DeviceContext *dev_ctx,
   }
 }
 
-void ConcatTensor(const phi::DeviceContext *dev_ctx,
+void ConcatTensor(const phi::DeviceContext &dev_ctx,
                   const std::vector<phi::DenseTensor> &tensor_list,
                   const experimental::Tensor *tensor) {
   auto *dense_tensor =
       std::dynamic_pointer_cast<phi::DenseTensor>(tensor->impl()).get();
 
-  const auto &place = dev_ctx->GetPlace();
+  const auto &place = dev_ctx.GetPlace();
   if (platform::is_gpu_place(place)) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    ConcatDenseTensorWithType(static_cast<const phi::GPUContext *>(dev_ctx),
+    ConcatDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
                               tensor_list,
                               dense_tensor,
                               tensor->dtype());
@@ -183,7 +183,7 @@ void ConcatTensor(const phi::DeviceContext *dev_ctx,
   } else if (platform::is_custom_place(place)) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
     ConcatDenseTensorWithType(
-        static_cast<const platform::CustomDeviceContext *>(dev_ctx),
+        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
         tensor_list,
         dense_tensor,
         tensor->dtype());
@@ -194,7 +194,7 @@ void ConcatTensor(const phi::DeviceContext *dev_ctx,
         "CUSTOM_DEVICE support."));
 #endif
   } else if (platform::is_cpu_place(place)) {
-    ConcatDenseTensorWithType(static_cast<const phi::CPUContext *>(dev_ctx),
+    ConcatDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
                               tensor_list,
                               dense_tensor,
                               tensor->dtype());
@@ -204,20 +204,20 @@ void ConcatTensor(const phi::DeviceContext *dev_ctx,
   }
 }
 
-void SplitTensor(const phi::DeviceContext *dev_ctx,
+void SplitTensor(const phi::DeviceContext &dev_ctx,
                  const phi::DenseTensor &tensor,
                  const std::vector<experimental::Tensor> *tensor_list) {
   std::vector<phi::DenseTensor *> dense_list;
   for (auto &tensor : *tensor_list) {
-    auto p_tensor =
+    auto *p_tensor =
         std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()).get();
     dense_list.emplace_back(p_tensor);
   }
 
-  const auto &place = dev_ctx->GetPlace();
+  const auto &place = dev_ctx.GetPlace();
   if (platform::is_gpu_place(place)) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    SplitDenseTensorWithType(static_cast<const phi::GPUContext *>(dev_ctx),
+    SplitDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
                              tensor,
                              &dense_list,
                              tensor.dtype());
@@ -229,7 +229,7 @@ void SplitTensor(const phi::DeviceContext *dev_ctx,
   } else if (platform::is_custom_place(place)) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
     SplitDenseTensorWithType(
-        static_cast<const platform::CustomDeviceContext *>(dev_ctx),
+        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
         tensor,
         &dense_list,
         tensor.dtype());
@@ -239,7 +239,7 @@ void SplitTensor(const phi::DeviceContext *dev_ctx,
         "please recompile or reinstall Paddle with CUSTOM_DEVICE support."));
 #endif
   } else if (platform::is_cpu_place(place)) {
-    SplitDenseTensorWithType(static_cast<const phi::CPUContext *>(dev_ctx),
+    SplitDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
                              tensor,
                              &dense_list,
                              tensor.dtype());
...
@@ -16,6 +16,8 @@
 #include "paddle/phi/backends/device_guard.h"
 #include "paddle/phi/backends/device_manager.h"
 
+DECLARE_bool(use_stream_safe_cuda_allocator);
+
 namespace paddle {
 namespace distributed {
@@ -335,13 +337,20 @@ void EagerGroup::ConcatTensors(const platform::Place &place) {
   }
 }
 
-void EagerGroup::SplitTensors(const platform::Place &place) {
+void EagerGroup::SplitTensorsDev(const platform::DeviceContext &context) {
+  auto place = context.GetPlace();
   if (platform::is_gpu_place(place)) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-    auto *default_ctx = static_cast<phi::GPUContext *>(
-        platform::DeviceContextPool::Instance().Get(place));
+    auto &gpu_context = static_cast<const phi::GPUContext &>(context);
     SplitTensorsWithType(
-        *default_ctx, &dense_contents_, &dense_tensors_, dtype_);
+        gpu_context, &dense_contents_, &dense_tensors_, dtype_);
+    if (FLAGS_use_stream_safe_cuda_allocator) {
+      auto dense_tensor =
+          std::dynamic_pointer_cast<phi::DenseTensor>(dense_contents_.impl());
+      VLOG(3) << "Free dense_contents_ " << dense_contents_.numel();
+      memory::RecordStream(dense_tensor->Holder(), gpu_context.stream());
+      dense_contents_.reset();
+    }
 #else
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Paddle can't split grad tensor since it's not compiled with NCCL,"
@@ -349,10 +358,11 @@ void EagerGroup::SplitTensors(const platform::Place &place) {
 #endif
   } else if (platform::is_custom_place(place)) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
-    auto *default_ctx = static_cast<platform::CustomDeviceContext *>(
-        platform::DeviceContextPool::Instance().Get(place));
     SplitTensorsWithType(
-        *default_ctx, &dense_contents_, &dense_tensors_, dtype_);
+        static_cast<const platform::CustomDeviceContext &>(context),
+        &dense_contents_,
+        &dense_tensors_,
+        dtype_);
 #else
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Paddle can't split grad tensor since it's not compiled with "
@@ -360,10 +370,10 @@ void EagerGroup::SplitTensors(const platform::Place &place) {
         "Please recompile or reinstall Paddle with CUSTOM_DEVICE support."));
 #endif
   } else if (platform::is_cpu_place(place)) {
-    auto *default_ctx = static_cast<phi::CPUContext *>(
-        platform::DeviceContextPool::Instance().Get(place));
-    SplitTensorsWithType(
-        *default_ctx, &dense_contents_, &dense_tensors_, dtype_);
+    SplitTensorsWithType(static_cast<const phi::CPUContext &>(context),
+                         &dense_contents_,
+                         &dense_tensors_,
+                         dtype_);
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Split grad tensor not supported on place (%s)", place));
@@ -578,9 +588,11 @@ void EagerReducer::TraverseBackwardGraph(const std::vector<Tensor> &outputs) {
   }
 }
 
-void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
+void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs,
+                                      const bool is_sync) {
   VLOG(3) << "after forward, then reset count for backward.";
-  grad_need_hooks_ = true;
+  grad_need_hooks_ = is_sync;
+
   next_group_ = 0;
   std::for_each(groups_.begin(), groups_.end(), [](EagerGroup &group) {
     group.pending_ = group.tensor_indices_.size();
@@ -648,9 +660,9 @@ void EagerReducer::AddDistHook(size_t var_index) {
                         var_index));
 
   // gradient synchronization is not required when grad_need_hooks_ is false.
-  if (!grad_need_hooks_) {
-    return;
-  }
+  // if (!grad_need_hooks_) {
+  //   return;
+  // }
 
   VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name()
           << "@Grad] arrived and triggered disthook";
@@ -816,10 +828,12 @@ void EagerReducer::MarkGroupReady(size_t group_index) {
   for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0;
        ++next_group_) {
     UNUSED auto &group = groups_[next_group_];
-    if (group.is_sparse_) {
-      AllReduceSparse(&group, next_group_);
-    } else {
-      FusedAllReduceSchedule(&group, next_group_);
+    if (grad_need_hooks_) {
+      if (group.is_sparse_) {
+        AllReduceSparse(&group, next_group_);
+      } else {
+        FusedAllReduceSchedule(&group, next_group_);
+      }
     }
   }
 }
@@ -907,16 +921,14 @@ void EagerReducer::ProcessUnusedDenseVars() {
 
 void EagerReducer::FinalizeBackward() {
   groups_need_finalize_ = false;
-  grad_need_hooks_ = false;
   for (auto &group : groups_) {
-    if (!group.is_sparse_) {
+    if (!group.is_sparse_ && grad_need_hooks_) {
       group.task->Synchronize();
     }
   }
 
   for (auto &group : groups_) {
-    if (!group.is_sparse_) {
-      group.SplitTensors(inner_place_);
+    if (!group.is_sparse_ && grad_need_hooks_) {
       group.dense_contents_.reset();
     }
   }
@@ -928,6 +940,7 @@ void EagerReducer::FinalizeBackward() {
     VLOG(3) << "ProcessUnusedDenseVars is finished.";
   }
 
+  grad_need_hooks_ = false;
   VLOG(3) << "In the batch, Reducer is finished.";
 }
@@ -954,6 +967,9 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
   }
 
   group->task = process_group_->AllReduce(in_out, in_out, opts);
+  const auto &context = process_group_->GetDeviceContext(inner_place_);
+  group->SplitTensorsDev(context);
+  group->task->UpdateWaitChain(context);
   // split in FinalizeBackward()
 }
...
@@ -74,7 +74,8 @@ class EagerGroup {
 
   void ConcatTensors(const platform::Place &);
 
-  void SplitTensors(const platform::Place &);
+  // context is used to select the stream for split
+  void SplitTensorsDev(const platform::DeviceContext &);
 
   friend std::ostream &operator<<(std::ostream &, const EagerGroup &);
 };
@@ -102,7 +103,8 @@ class EagerReducer {
   void InitializeGroups(const std::vector<std::vector<size_t>> &group_indices);
   void InitializeDenseGroups(const std::vector<size_t> &tensor_indices_,
                              EagerGroup *p_group);
-  void PrepareForBackward(const std::vector<Tensor> &outputs);
+  void PrepareForBackward(const std::vector<Tensor> &outputs,
+                          const bool is_sync);
   void AddDistHook(size_t var_index);
   void MarkVarReady(const size_t var_index, const bool is_used_var);
   void MarkGroupReady(const size_t group_index);
...
@@ -675,9 +675,10 @@ void Reducer::TraverseBackwardGraph(
 // After each batch is calculated, the counter of each group(group.pending_)
 // and allreudce sequence counter(next_group_) will be cleaned up again.
 void Reducer::PrepareForBackward(
-    const std::vector<std::shared_ptr<imperative::VarBase>> &outputs) {
+    const std::vector<std::shared_ptr<imperative::VarBase>> &outputs,
+    const bool is_sync) {
   VLOG(3) << "after forward, then reset count for backward.";
-  grad_need_hooks_ = true;
+  grad_need_hooks_ = is_sync;
   next_group_ = 0;
   std::for_each(groups_.begin(), groups_.end(), [](Group &group) {
     group.pending_ = group.variable_indices_.size();
@@ -710,7 +711,9 @@ void Reducer::PrepareForBackward(
 
   if (find_unused_vars_once_ || find_unused_vars_each_step_) {
     unused_vars_.clear();
-    TraverseBackwardGraph(outputs);
+    if (grad_need_hooks_) {
+      TraverseBackwardGraph(outputs);
+    }
     // only check once in first step
     find_unused_vars_once_ = false;
   }
...
@@ -146,7 +146,8 @@ class Reducer {
   void PrepareDeps(const std::unordered_set<GradOpNode*>& init_nodes);
 
   void PrepareForBackward(
-      const std::vector<std::shared_ptr<imperative::VarBase>>& outputs);
+      const std::vector<std::shared_ptr<imperative::VarBase>>& outputs,
+      const bool is_sync);
 
   void AddDistHook(size_t var_index);
...
@@ -395,9 +395,10 @@ void BindDistributed(py::module *m) {
                     concat_out_tensor.impl());
             std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
 
-            const auto *dev_ctx = self.GetDeviceContext(in_tensor.place());
+            const auto &dev_ctx = self.GetDeviceContext(in_tensor.place());
             auto task = self.AllGather(in_wrapper, out_wrapper, sync_op);
             distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
+            task->UpdateWaitChain(dev_ctx);
             return task;
           },
           py::arg("in"),
@@ -495,10 +496,11 @@ void BindDistributed(py::module *m) {
             std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
 
             // in_tensor_list should not be empty
-            const auto *dev_ctx =
+            const auto &dev_ctx =
                 self.GetDeviceContext(in_tensor_list.back().place());
             auto task = self.AllToAll(in_wrapper, out_wrapper, sync_op);
             distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
+            task->UpdateWaitChain(dev_ctx);
             return task;
           },
           py::arg("in"),
@@ -796,7 +798,7 @@ void BindDistributed(py::module *m) {
                     concat_out_tensor.impl());
             std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
 
-            const auto *dev_ctx =
+            const auto &dev_ctx =
                 self.GetDeviceContext(in_tensor.place(), true);
             auto task = self.AllGather(in_wrapper,
                                        out_wrapper,
@@ -905,7 +907,7 @@ void BindDistributed(py::module *m) {
             std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
 
             // in_tensor_list must not be empty
-            const auto *dev_ctx = self.GetDeviceContext(
+            const auto &dev_ctx = self.GetDeviceContext(
                 in_tensor_list.back().place(), /*use_calc_stream*/ true);
             auto task = self.AllToAll(in_wrapper,
                                       out_wrapper,
@@ -1405,11 +1407,14 @@ void BindDistributed(py::module *m) {
       .def(py::init(&CreateEagerReducer))
       .def(
          "prepare_for_backward",
-          [](distributed::EagerReducer &self, py::handle py_tensors) {
+          [](distributed::EagerReducer &self,
+             py::handle py_tensors,
+             bool is_sync) {
            auto params = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0);
-            self.PrepareForBackward(params);
+            self.PrepareForBackward(params, is_sync);
          },
          py::arg("tensors"),
+          py::arg("is_sync"),
          py::call_guard<py::gil_scoped_release>());
 }
...
@@ -2569,6 +2569,7 @@ void BindImperative(py::module *m_ptr) {
       .def("prepare_for_backward",
            &imperative::Reducer::PrepareForBackward,
            py::arg("vars"),
+           py::arg("is_sync"),
            py::call_guard<py::gil_scoped_release>());
 
   m.def("assign_group_by_size",
...
@@ -818,13 +818,9 @@ class DataParallel(layers.Layer):
 
     def forward(self, *inputs, **kwargs):
         outputs = self._layers(*inputs, **kwargs)
-        if (
-            self._strategy.nranks > 1
-            and framework._dygraph_tracer()._has_grad
-            and self.grad_need_sync
-        ):
+        if self._strategy.nranks > 1 and framework._dygraph_tracer()._has_grad:
             self._reducer.prepare_for_backward(
-                list(self._find_varbase(outputs))
+                list(self._find_varbase(outputs)), self.grad_need_sync
             )
         return outputs
...
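
The grad_need_sync flag forwarded above is the one toggled by DataParallel.no_sync(); after this change the reducer is always prepared for backward, and the flag only controls whether the fused allreduce is actually scheduled. A minimal usage sketch of that no-sync path (hypothetical model, data, and hyperparameters; assumes the script is launched on multiple GPUs, e.g. via paddle.distributed.launch):

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    model = paddle.DataParallel(paddle.nn.Linear(8, 8))
    opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
    x = paddle.randn([4, 8])

    # Gradients only accumulate locally here: prepare_for_backward receives
    # is_sync=False, so no fused allreduce task is scheduled for this step.
    with model.no_sync():
        model(x).mean().backward()

    # Regular step: is_sync=True, gradients are fused, allreduced, and split back.
    model(x).mean().backward()
    opt.step()
    opt.clear_grad()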