From 47388020a2e8e702191369f578fd558fe338d723 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 29 Jun 2018 03:42:18 +0000 Subject: [PATCH] fix bugs --- .../framework/details/data_balance_op_handle.cc | 15 +++++++++++++++ .../framework/details/data_balance_op_handle.h | 11 ++++++++++- .../details/multi_devices_graph_builder.cc | 5 +++++ paddle/fluid/framework/details/op_handle_base.cc | 1 + 4 files changed, 31 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc index 24a68506e88..023e0cdf917 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.cc +++ b/paddle/fluid/framework/details/data_balance_op_handle.cc @@ -20,10 +20,24 @@ namespace paddle { namespace framework { namespace details { +#ifdef PADDLE_WITH_CUDA +DataBalanceOpHandle::DataBalanceOpHandle( + const std::vector &local_scopes, + const std::vector &places, + const platform::NCCLContextMap *ctxs) + : local_scopes_(local_scopes), places_(places) { + if (ctxs) { + for (auto &p : places_) { + this->dev_ctxes_[p] = ctxs->DevCtx(p); + } + } +} +#else DataBalanceOpHandle::DataBalanceOpHandle( const std::vector &local_scopes, const std::vector &places) : local_scopes_(local_scopes), places_(places) {} +#endif std::string DataBalanceOpHandle::Name() const { return "data balance"; } @@ -104,6 +118,7 @@ void DataBalanceOpHandle::RunImpl() { } } const auto &balance_plan = GetBalancePlan(device_sizes); + for (const auto &trans : balance_plan) { for (int data_idx = 0; data_idx < data_num; ++data_idx) { LoDTensor *src_tensor = lod_tensors[data_idx][trans[0]]; diff --git a/paddle/fluid/framework/details/data_balance_op_handle.h b/paddle/fluid/framework/details/data_balance_op_handle.h index 5552be2e6eb..a4adafdfeb1 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.h +++ b/paddle/fluid/framework/details/data_balance_op_handle.h @@ -19,6 +19,9 @@ #include 
"paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/nccl_helper.h" +#endif namespace paddle { namespace framework { @@ -26,8 +29,14 @@ namespace details { struct DataBalanceOpHandle : public OpHandleBase { public: +#ifdef PADDLE_WITH_CUDA DataBalanceOpHandle(const std::vector &local_scopes, - const std::vector &places); + const std::vector &places, + const platform::NCCLContextMap *ctxs); +#else + DataBalanceOpHandle(const std::vector &local_scopes, + const std::vector &places); +#endif std::string Name() const override; diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 4ddc1f2dddc..8a9f0b10545 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -368,7 +368,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, void MultiDevSSAGraphBuilder::InsertDataBalanceOp( SSAGraph *result, const std::vector &datas) const { +#ifdef PADDLE_WITH_CUDA + result->ops_.emplace_back( + new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); +#else result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_)); +#endif auto *op_handle = result->ops_.back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index 1f84c3b9e2d..856124875d5 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -60,6 +60,7 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { #ifdef PADDLE_WITH_CUDA if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) { for (auto &dev_ctx : dev_ctxes_) { + 
PADDLE_ENFORCE_NOT_NULL(dev_ctx.second); dev_ctx.second->Wait(); } } else { -- GitLab