diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc index 24a68506e88fee7523e05f20bc28edeb4f8c2b7b..023e0cdf9175d9f1c33505d8c0b461efdc407061 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.cc +++ b/paddle/fluid/framework/details/data_balance_op_handle.cc @@ -20,10 +20,24 @@ namespace paddle { namespace framework { namespace details { +#ifdef PADDLE_WITH_CUDA +DataBalanceOpHandle::DataBalanceOpHandle( + const std::vector &local_scopes, + const std::vector &places, + const platform::NCCLContextMap *ctxs) + : local_scopes_(local_scopes), places_(places) { + if (ctxs) { + for (auto &p : places_) { + this->dev_ctxes_[p] = ctxs->DevCtx(p); + } + } +} +#else DataBalanceOpHandle::DataBalanceOpHandle( const std::vector &local_scopes, const std::vector &places) : local_scopes_(local_scopes), places_(places) {} +#endif std::string DataBalanceOpHandle::Name() const { return "data balance"; } @@ -104,6 +118,7 @@ void DataBalanceOpHandle::RunImpl() { } } const auto &balance_plan = GetBalancePlan(device_sizes); + for (const auto &trans : balance_plan) { for (int data_idx = 0; data_idx < data_num; ++data_idx) { LoDTensor *src_tensor = lod_tensors[data_idx][trans[0]]; diff --git a/paddle/fluid/framework/details/data_balance_op_handle.h b/paddle/fluid/framework/details/data_balance_op_handle.h index 5552be2e6eb2f67eb56022a37fcfb0c27c3df073..a4adafdfeb19fa9ec57c5b3c5313887e8361d378 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.h +++ b/paddle/fluid/framework/details/data_balance_op_handle.h @@ -19,6 +19,9 @@ #include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/nccl_helper.h" +#endif namespace paddle { namespace framework { @@ -26,8 +29,14 @@ namespace details { struct DataBalanceOpHandle : public OpHandleBase { public: +#ifdef PADDLE_WITH_CUDA DataBalanceOpHandle(const std::vector &local_scopes, - const std::vector &places); + const std::vector &places, + const platform::NCCLContextMap *ctxs); +#else + DataBalanceOpHandle(const std::vector &local_scopes, + const std::vector *places) +#endif std::string Name() const override; diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 4ddc1f2dddc91bb832a2b9fee282bf9964fb7921..8a9f0b1054575b41c645d81ab70f5f2a37fd8845 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -368,7 +368,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, void MultiDevSSAGraphBuilder::InsertDataBalanceOp( SSAGraph *result, const std::vector &datas) const { +#ifdef PADDLE_WITH_CUDA + result->ops_.emplace_back( + new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); +#else result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_)); +#endif auto *op_handle = result->ops_.back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index 1f84c3b9e2d7ee9ae51959988fceeb3451b7b3b8..856124875d55e65428a7fb23e402c0d311900724 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -60,6 +60,7 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { #ifdef PADDLE_WITH_CUDA if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) { for (auto &dev_ctx : dev_ctxes_) { + PADDLE_ENFORCE_NOT_NULL(dev_ctx.second); dev_ctx.second->Wait(); } } else {