From f720e231e5e646b0d88c6e1c5ebf9a2ab010e591 Mon Sep 17 00:00:00 2001
From: Ruibiao Chen
Date: Thu, 30 Jun 2022 14:45:45 +0800
Subject: [PATCH] Remove boost::variant for FetchResultType (#43932)

* Remove boost::variant for FetchResultType

* Fix pybind errors
---
 .../details/async_ssa_graph_executor.cc      |  2 +-
 .../details/fetch_async_op_handle.cc         |  4 +-
 .../framework/details/fetch_op_handle.cc     |  6 +-
 .../details/parallel_ssa_graph_executor.cc   |  5 +-
 paddle/fluid/framework/feed_fetch_type.h     |  2 +-
 paddle/fluid/framework/parallel_executor.cc  | 86 ++++++++++++-------
 paddle/fluid/framework/parallel_executor.h   |  6 +-
 paddle/fluid/pybind/pybind.cc                | 40 +++++----
 8 files changed, 90 insertions(+), 61 deletions(-)

diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
index f22e62fa0a..0ae6969554 100644
--- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
@@ -174,7 +174,7 @@ FetchResultType AsyncSSAGraphExecutor::Run(
   HandleException();
 
   FetchList ret;
-  auto &val = boost::get<FetchList>(fetch_data);
+  auto &val = BOOST_GET(FetchList, fetch_data);
   for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
     if (data_is_lod_tensor(val.at(fetch_idx))) {
       std::vector<const LoDTensor *> lodtensor_ptrs;
diff --git a/paddle/fluid/framework/details/fetch_async_op_handle.cc b/paddle/fluid/framework/details/fetch_async_op_handle.cc
index 8d8bb96f5c..a9e4bf826b 100644
--- a/paddle/fluid/framework/details/fetch_async_op_handle.cc
+++ b/paddle/fluid/framework/details/fetch_async_op_handle.cc
@@ -228,7 +228,7 @@ void FetchAsyncOpHandle::RunImpl() {
   }
 
   if (return_merged_) {
-    auto &val = boost::get<FetchList>(*data_);
+    auto &val = BOOST_GET(FetchList, *data_);
     if (src_vars[0]->IsType<LoDTensor>()) {
       // to lodtensor type
       std::vector<const LoDTensor *> src_lodtensors;
@@ -263,7 +263,7 @@ void FetchAsyncOpHandle::RunImpl() {
       val.at(offset_) = std::move(dst_lodtensor_array);
     }
   } else {
-    auto &val = boost::get<FetchUnmergedList>(*data_);
+    auto &val = BOOST_GET(FetchUnmergedList, *data_);
     auto &dst_tensors = val.at(offset_);
     dst_tensors.reserve(src_vars.size());
 
diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc
index f160650f0b..a9f7de8ee3 100644
--- a/paddle/fluid/framework/details/fetch_op_handle.cc
+++ b/paddle/fluid/framework/details/fetch_op_handle.cc
@@ -84,7 +84,7 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
     for (auto &t : tensors_) {
      tensors_ptr.emplace_back(&BOOST_GET_CONST(LoDTensor, t));
     }
-    auto &val = boost::get<FetchList>(*data_);
+    auto &val = BOOST_GET(FetchList, *data_);
     LoDTensor var;
     MergeLoDTensor(&var, tensors_ptr, platform::CPUPlace());
     val.at(offset_) = std::move(var);
@@ -106,11 +106,11 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
       tmp_array.emplace_back();
       MergeLoDTensor(&(tmp_array.back()), tensors_ptr, platform::CPUPlace());
     }
-    auto &val = boost::get<FetchList>(*data_);
+    auto &val = BOOST_GET(FetchList, *data_);
     val.at(offset_) = std::move(tmp_array);
   }
 } else {
-  auto &val = boost::get<FetchUnmergedList>(*data_);
+  auto &val = BOOST_GET(FetchUnmergedList, *data_);
   val.at(offset_) = std::move(tensors_);
 }
 }
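A note on the pattern in the hunks above: each one is the same mechanical substitution, replacing unchecked boost::get<T>(v) with Paddle's BOOST_GET / BOOST_GET_CONST macros, which check the active alternative before access. A minimal sketch of what such a checked get buys, using std::variant as a stand-in for paddle::variant and a hypothetical CHECKED_GET macro (not Paddle's actual definition):

    #include <cstdio>
    #include <cstdlib>
    #include <variant>
    #include <vector>

    using FetchList = std::vector<int>;                       // stand-in
    using FetchUnmergedList = std::vector<std::vector<int>>;  // stand-in
    using FetchResultType = std::variant<FetchList, FetchUnmergedList>;

    // Report which alternative was requested before aborting, instead of
    // throwing an opaque bad_variant_access as a bare get would.
    #define CHECKED_GET(Type, value)                                 \
      ([&]() -> Type & {                                             \
        auto *p = std::get_if<Type>(&(value));                       \
        if (p == nullptr) {                                          \
          std::fprintf(stderr, "variant does not hold " #Type "\n"); \
          std::abort();                                              \
        }                                                            \
        return *p;                                                   \
      }())

    int main() {
      FetchResultType fetch_data = FetchList{1, 2, 3};
      auto &val = CHECKED_GET(FetchList, fetch_data);  // ok
      std::printf("%zu values fetched\n", val.size());
      // CHECKED_GET(FetchUnmergedList, fetch_data);   // would abort with a message
      return 0;
    }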
diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
index 86536b74a3..bc870c0eaa 100644
--- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
@@ -278,7 +278,8 @@ FetchResultType ParallelSSAGraphExecutor::Run(
         if (!is_valid[scope_idx]) {
           continue;
         }
-        const auto &fetch_list = boost::get<FetchList>(fetch_data[scope_idx]);
+        const auto &fetch_list =
+            BOOST_GET_CONST(FetchList, fetch_data[scope_idx]);
         if (data_is_lod_tensor(fetch_list[fetch_idx])) {
           lodtensor_ptrs.push_back(
               &(BOOST_GET_CONST(LoDTensor, fetch_list[fetch_idx])));
@@ -317,7 +318,7 @@ FetchResultType ParallelSSAGraphExecutor::Run(
           continue;
         }
         const auto &fetch_list =
-            boost::get<FetchUnmergedList>(fetch_data[scope_idx]);
+            BOOST_GET_CONST(FetchUnmergedList, fetch_data[scope_idx]);
         PADDLE_ENFORCE_EQ(
             fetch_list[fetch_idx].size(),
             1,
diff --git a/paddle/fluid/framework/feed_fetch_type.h b/paddle/fluid/framework/feed_fetch_type.h
index c86cdc9981..8ecd6a0339 100644
--- a/paddle/fluid/framework/feed_fetch_type.h
+++ b/paddle/fluid/framework/feed_fetch_type.h
@@ -30,7 +30,7 @@ using FetchType = paddle::variant<LoDTensor, LoDTensorArray>;
 using FetchList = std::vector<FetchType>;
 using FetchUnmergedList = std::vector<std::vector<FetchType>>;
 
-using FetchResultType = boost::variant<FetchList, FetchUnmergedList>;
+using FetchResultType = paddle::variant<FetchList, FetchUnmergedList>;
 
 inline bool data_is_lod_tensor(const FetchType &data) {
   if (data.type() == typeid(LoDTensor)) {
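The header change above is the heart of the PR: FetchResultType switches from boost::variant to paddle::variant, taking the fetch path off Boost. For orientation, a sketch of the two-level shape these aliases describe, with std::variant and a placeholder tensor type standing in for the Paddle originals (an assumption for illustration; paddle::variant keeps a boost-like type() query, spelled here as holds_alternative):

    #include <variant>
    #include <vector>

    struct LoDTensor {};                            // placeholder tensor type
    using LoDTensorArray = std::vector<LoDTensor>;

    // One fetched variable: either a single tensor or an array of them.
    using FetchType = std::variant<LoDTensor, LoDTensorArray>;
    // Merged result: one entry per fetch target, merged across devices.
    using FetchList = std::vector<FetchType>;
    // Unmerged result: per fetch target, one copy per device.
    using FetchUnmergedList = std::vector<std::vector<FetchType>>;
    // What an executor Run() hands back: merged or unmerged, never both.
    using FetchResultType = std::variant<FetchList, FetchUnmergedList>;

    inline bool data_is_lod_tensor(const FetchType &data) {
      // Same test the header spells as data.type() == typeid(LoDTensor).
      return std::holds_alternative<LoDTensor>(data);
    }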
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index fffacc59ba..697cb8cdcf 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -972,37 +972,26 @@ void ParallelExecutor::BCastParamsToDevices(
   }
 }
 
-FetchResultType ParallelExecutor::Run(
-    const std::vector<std::string> &fetch_tensors, bool return_merged) {
-  platform::RecordEvent record_run(
-      "ParallelExecutor::Run", platform::TracerEventType::UserDefined, 1);
-  VLOG(3) << "enter ParallelExecutor Run";
-#ifdef PADDLE_WITH_CUDA
-  if (platform::IsCUDAGraphCapturing()) {
-    PADDLE_ENFORCE_EQ(fetch_tensors.empty(),
-                      true,
-                      platform::errors::InvalidArgument(
-                          "Cannot fetch data when using CUDA Graph."));
-    PADDLE_ENFORCE_EQ(
-        member_->build_strategy_.allow_cuda_graph_capture_,
-        true,
-        platform::errors::InvalidArgument(
-            "You must turn on build_strategy.allow_cuda_graph_capture = True "
-            "to enable CUDA Graph capturing."));
-    PADDLE_ENFORCE_EQ(
-        member_->places_[0],
-        platform::CUDAGraphCapturingPlace(),
-        platform::errors::InvalidArgument("The place to capture CUDAGraph is "
-                                          "not the same as the place to run."));
-  }
-#endif
+FetchUnmergedList ParallelExecutor::Run(
+    const std::vector<std::string> &fetch_tensors) {
+  PreludeToRun(fetch_tensors);
   platform::RecordBlock b(0);
 
-#ifdef WITH_GPERFTOOLS
-  if (gProfileStarted) {
-    ProfilerFlush();
-  }
-#endif
+  ResetHasFeedGuard reset_has_feed_guard(member_);
+
+  ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_),
+                                fetch_tensors,
+                                member_->HasGarbageCollectors());
+
+  VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run";
+  auto fetch_data =
+      member_->executor_->Run(fetch_tensors, /*return_merged=*/false);
+  return BOOST_GET(FetchUnmergedList, fetch_data);
+}
+
+FetchList ParallelExecutor::RunAndMerge(
+    const std::vector<std::string> &fetch_tensors) {
+  PreludeToRun(fetch_tensors);
+  platform::RecordBlock b(0);
 
   ResetHasFeedGuard reset_has_feed_guard(member_);
 
@@ -1011,9 +1000,10 @@ FetchResultType ParallelExecutor::Run(
       fetch_tensors,
       member_->HasGarbageCollectors());
 
-  VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run";
-  auto fetch_data = member_->executor_->Run(fetch_tensors, return_merged);
-  return fetch_data;
+  VLOG(3) << "ParallelExecutor begin to run member_->executor_->RunAndMerge";
+  auto fetch_data =
+      member_->executor_->Run(fetch_tensors, /*return_merged=*/true);
+  return BOOST_GET(FetchList, fetch_data);
 }
 
 void ParallelExecutor::RunWithoutFetch(
@@ -1440,6 +1430,38 @@ std::vector<ir::Graph *> ParallelExecutor::CloneGraphToMultiDevices(
   return graphs;
 }
 
+void ParallelExecutor::PreludeToRun(
+    const std::vector<std::string> &fetch_tensors) {
+  platform::RecordEvent record_run(
+      "ParallelExecutor::Run", platform::TracerEventType::UserDefined, 1);
+  VLOG(3) << "enter ParallelExecutor Run";
+#ifdef PADDLE_WITH_CUDA
+  if (platform::IsCUDAGraphCapturing()) {
+    PADDLE_ENFORCE_EQ(fetch_tensors.empty(),
+                      true,
+                      platform::errors::InvalidArgument(
+                          "Cannot fetch data when using CUDA Graph."));
+    PADDLE_ENFORCE_EQ(
+        member_->build_strategy_.allow_cuda_graph_capture_,
+        true,
+        platform::errors::InvalidArgument(
+            "You must turn on build_strategy.allow_cuda_graph_capture = True "
+            "to enable CUDA Graph capturing."));
+    PADDLE_ENFORCE_EQ(
+        member_->places_[0],
+        platform::CUDAGraphCapturingPlace(),
+        platform::errors::InvalidArgument("The place to capture CUDAGraph is "
+                                          "not the same as the place to run."));
+  }
+#endif
+
+#ifdef WITH_GPERFTOOLS
+  if (gProfileStarted) {
+    ProfilerFlush();
+  }
+#endif
+}
+
 void ParallelExecutor::PrepareNCCLCommunicator(Scope *global_scope) {
   if (member_->build_strategy_.reduce_ ==
       BuildStrategy::ReduceStrategy::kNoReduce) {
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index 4cb9c0340b..a3b812a71a 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -89,8 +89,8 @@ class ParallelExecutor {
   void FeedAndSplitTensorIntoLocalScopes(
       const std::unordered_map<std::string, LoDTensor> &tensors);
 
-  FetchResultType Run(const std::vector<std::string> &fetch_tensors,
-                      bool return_merged = true);
+  FetchUnmergedList Run(const std::vector<std::string> &fetch_tensors);
+  FetchList RunAndMerge(const std::vector<std::string> &fetch_tensors);
 
   void RunWithoutFetch(const std::vector<std::string> &skip_eager_vars);
 
@@ -126,6 +126,8 @@ class ParallelExecutor {
 
   std::vector<ir::Graph *> CloneGraphToMultiDevices(ir::Graph *graph);
 
+  void PreludeToRun(const std::vector<std::string> &fetch_tensors);
+
   void PrepareNCCLCommunicator(Scope *global_scope);
 
   std::vector<ir::Graph *> CompileGraphWithBuildStrategy(
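The executor change above splits the old bool-parameterized Run(fetch_tensors, return_merged), which forced every caller to unpack a FetchResultType, into two entry points whose return types state the merge behavior themselves. A toy model of that refactor, using plain std:: types rather than anything Paddle-specific:

    #include <string>
    #include <vector>

    // Before (sketch): one method, variant return, caller must track the flag:
    //   FetchResultType Run(const std::vector<std::string> &names, bool return_merged);
    // After: the signature itself documents what comes back.
    struct ToyExecutor {
      // One merged value per fetch target.
      std::vector<int> RunAndMerge(const std::vector<std::string> &names) {
        return std::vector<int>(names.size(), 0);
      }
      // Per target, one value per device (two devices here).
      std::vector<std::vector<int>> Run(const std::vector<std::string> &names) {
        return std::vector<std::vector<int>>(names.size(),
                                             std::vector<int>(2, 0));
      }
    };

    int main() {
      ToyExecutor pe;
      auto merged = pe.RunAndMerge({"loss"});  // typed result, no variant unpacking
      auto per_device = pe.Run({"loss"});
      return merged.size() == 1 && per_device[0].size() == 2 ? 0 : 1;
    }

The same trade-off shows up in the pybind hunk below: each branch of the Python-facing lambda now calls the entry point whose C++ return type it needs, so the variant only has to be unpacked once, inside the executor.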
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index c294c8eb4a..18a3fb1aab 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -3225,13 +3225,17 @@ All parameter, weight, gradient are variables in Paddle.
 #endif
 
   m.def("set_feed_variable",
-        static_cast<void (*)(Scope *, const LoDTensor &, const std::string &, size_t)>(
-            &framework::SetFeedVariable));
+        static_cast<void (*)(Scope *,
+                             const LoDTensor &,
+                             const std::string &,
+                             size_t)>(&framework::SetFeedVariable));
   m.def("set_feed_variable",
-        static_cast<void (*)(Scope *, const Strings &, const std::string &, size_t)>(
-            &framework::SetFeedVariable));
+        static_cast<void (*)(Scope *,
+                             const Strings &,
+                             const std::string &,
+                             size_t)>(&framework::SetFeedVariable));
   m.def("get_fetch_variable",
         [](const Scope &scope,
            const std::string &var_name,
@@ -4601,20 +4605,20 @@ All parameter, weight, gradient are variables in Paddle.
         [](ParallelExecutor &self,
            const std::vector<std::string> &fetch_tensors,
            bool return_merged) -> py::object {
-          paddle::framework::FetchResultType ret;
-          {
-            pybind11::gil_scoped_release release;
-            ret = self.Run(fetch_tensors, return_merged);
-          }
-
-          // TODO(Ruibiao): Refactor the run interface of PE to avoid use
-          // boost::get here
           if (return_merged) {
-            return py::cast(
-                std::move(boost::get<paddle::framework::FetchList>(ret)));
+            paddle::framework::FetchList ret;
+            /*gil_scoped_release*/ {
+              pybind11::gil_scoped_release release;
+              ret = self.RunAndMerge(fetch_tensors);
+            }
+            return py::cast(std::move(ret));
           } else {
-            return py::cast(std::move(
-                boost::get<paddle::framework::FetchUnmergedList>(ret)));
+            paddle::framework::FetchUnmergedList ret;
+            /*gil_scoped_release*/ {
+              pybind11::gil_scoped_release release;
+              ret = self.Run(fetch_tensors);
+            }
+            return py::cast(std::move(ret));
           }
         })
     .def("device_count", &ParallelExecutor::DeviceCount);
--
GitLab