diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
index f22e62fa0aa5bd9b1a0445e46022d4fe9c605b88..0ae69695549e529d821b28480c0eec9ab0be532a 100644
--- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
@@ -174,7 +174,7 @@ FetchResultType AsyncSSAGraphExecutor::Run(
   HandleException();
 
   FetchList ret;
-  auto &val = boost::get<FetchList>(fetch_data);
+  auto &val = BOOST_GET(FetchList, fetch_data);
   for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
     if (data_is_lod_tensor(val.at(fetch_idx))) {
       std::vector<const LoDTensor *> lodtensor_ptrs;
diff --git a/paddle/fluid/framework/details/fetch_async_op_handle.cc b/paddle/fluid/framework/details/fetch_async_op_handle.cc
index 8d8bb96f5c8edb29ee9ac5295df7a28c98a834a1..a9e4bf826bc4b2ef44fe0e416429ed1b1ceb33f5 100644
--- a/paddle/fluid/framework/details/fetch_async_op_handle.cc
+++ b/paddle/fluid/framework/details/fetch_async_op_handle.cc
@@ -228,7 +228,7 @@ void FetchAsyncOpHandle::RunImpl() {
   }
 
   if (return_merged_) {
-    auto &val = boost::get<FetchList>(*data_);
+    auto &val = BOOST_GET(FetchList, *data_);
     if (src_vars[0]->IsType<LoDTensor>()) {
       // to lodtensor type
       std::vector<const LoDTensor *> src_lodtensors;
@@ -263,7 +263,7 @@ void FetchAsyncOpHandle::RunImpl() {
       val.at(offset_) = std::move(dst_lodtensor_array);
     }
   } else {
-    auto &val = boost::get<FetchUnmergedList>(*data_);
+    auto &val = BOOST_GET(FetchUnmergedList, *data_);
     auto &dst_tensors = val.at(offset_);
     dst_tensors.reserve(src_vars.size());
diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc
index f160650f0b9f4c464411cb871d287eeb16fe5ba5..a9f7de8ee312f914057471bc4741c0bf4aefb536 100644
--- a/paddle/fluid/framework/details/fetch_op_handle.cc
+++ b/paddle/fluid/framework/details/fetch_op_handle.cc
@@ -84,7 +84,7 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
     for (auto &t : tensors_) {
       tensors_ptr.emplace_back(&BOOST_GET_CONST(LoDTensor, t));
     }
-    auto &val = boost::get<FetchList>(*data_);
+    auto &val = BOOST_GET(FetchList, *data_);
     LoDTensor var;
     MergeLoDTensor(&var, tensors_ptr, platform::CPUPlace());
     val.at(offset_) = std::move(var);
@@ -106,11 +106,11 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
         tmp_array.emplace_back();
         MergeLoDTensor(&(tmp_array.back()), tensors_ptr, platform::CPUPlace());
       }
-      auto &val = boost::get<FetchList>(*data_);
+      auto &val = BOOST_GET(FetchList, *data_);
       val.at(offset_) = std::move(tmp_array);
     }
   } else {
-    auto &val = boost::get<FetchUnmergedList>(*data_);
+    auto &val = BOOST_GET(FetchUnmergedList, *data_);
     val.at(offset_) = std::move(tensors_);
   }
 }
diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
index 86536b74a3d7c9f04c501a72dff433080eed3a42..bc870c0eaa18d932b74a9a78f128aa8c16a1bfbd 100644
--- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
@@ -278,7 +278,8 @@ FetchResultType ParallelSSAGraphExecutor::Run(
         if (!is_valid[scope_idx]) {
          continue;
         }
-        const auto &fetch_list = boost::get<FetchList>(fetch_data[scope_idx]);
+        const auto &fetch_list =
+            BOOST_GET_CONST(FetchList, fetch_data[scope_idx]);
         if (data_is_lod_tensor(fetch_list[fetch_idx])) {
           lodtensor_ptrs.push_back(
               &(BOOST_GET_CONST(LoDTensor, fetch_list[fetch_idx])));
@@ -317,7 +318,7 @@ FetchResultType ParallelSSAGraphExecutor::Run(
           continue;
         }
         const auto &fetch_list =
-            boost::get<FetchUnmergedList>(fetch_data[scope_idx]);
+            BOOST_GET_CONST(FetchUnmergedList, fetch_data[scope_idx]);
         PADDLE_ENFORCE_EQ(
             fetch_list[fetch_idx].size(),
             1,
diff --git a/paddle/fluid/framework/feed_fetch_type.h b/paddle/fluid/framework/feed_fetch_type.h
index c86cdc998133b8d674a667c98b90fb18e2e3eff3..8ecd6a0339b5bc21ab6f5f534960cde4ca652212 100644
--- a/paddle/fluid/framework/feed_fetch_type.h
+++ b/paddle/fluid/framework/feed_fetch_type.h
@@ -30,7 +30,7 @@
 using FetchType = paddle::variant<LoDTensor, LoDTensorArray, framework::Vocab>;
 using FetchList = std::vector<FetchType>;
 using FetchUnmergedList = std::vector<std::vector<FetchType>>;
-using FetchResultType = boost::variant<FetchList, FetchUnmergedList>;
+using FetchResultType = paddle::variant<FetchList, FetchUnmergedList>;
 
 inline bool data_is_lod_tensor(const FetchType &data) {
   if (data.type() == typeid(LoDTensor)) {
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index fffacc59ba7bc5c89a0dfb949d2b6c263cdb8ec2..697cb8cdcf6e829db06197d8e5240356630cdb9f 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -972,37 +972,26 @@ void ParallelExecutor::BCastParamsToDevices(
   }
 }
 
-FetchResultType ParallelExecutor::Run(
-    const std::vector<std::string> &fetch_tensors, bool return_merged) {
-  platform::RecordEvent record_run(
-      "ParallelExecutor::Run", platform::TracerEventType::UserDefined, 1);
-  VLOG(3) << "enter ParallelExecutor Run";
-#ifdef PADDLE_WITH_CUDA
-  if (platform::IsCUDAGraphCapturing()) {
-    PADDLE_ENFORCE_EQ(fetch_tensors.empty(),
-                      true,
-                      platform::errors::InvalidArgument(
-                          "Cannot fetch data when using CUDA Graph."));
-    PADDLE_ENFORCE_EQ(
-        member_->build_strategy_.allow_cuda_graph_capture_,
-        true,
-        platform::errors::InvalidArgument(
-            "You must turn on build_strategy.allow_cuda_graph_capture = True "
-            "to enable CUDA Graph capturing."));
-    PADDLE_ENFORCE_EQ(
-        member_->places_[0],
-        platform::CUDAGraphCapturingPlace(),
-        platform::errors::InvalidArgument("The place to capture CUDAGraph is "
-                                          "not the same as the place to run."));
-  }
-#endif
+FetchUnmergedList ParallelExecutor::Run(
+    const std::vector<std::string> &fetch_tensors) {
+  PreludeToRun(fetch_tensors);
+  platform::RecordBlock b(0);
 
-#ifdef WITH_GPERFTOOLS
-  if (gProfileStarted) {
-    ProfilerFlush();
-  }
-#endif
+  ResetHasFeedGuard reset_has_feed_guard(member_);
+
+  ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_),
+                                fetch_tensors,
+                                member_->HasGarbageCollectors());
+  VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run";
+  auto fetch_data =
+      member_->executor_->Run(fetch_tensors, /*return_merged=*/false);
+  return BOOST_GET(FetchUnmergedList, fetch_data);
+}
+
+FetchList ParallelExecutor::RunAndMerge(
+    const std::vector<std::string> &fetch_tensors) {
+  PreludeToRun(fetch_tensors);
   platform::RecordBlock b(0);
 
   ResetHasFeedGuard reset_has_feed_guard(member_);
 
@@ -1011,9 +1000,10 @@ FetchResultType ParallelExecutor::Run(
       fetch_tensors,
       member_->HasGarbageCollectors());
 
-  VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run";
-  auto fetch_data = member_->executor_->Run(fetch_tensors, return_merged);
-  return fetch_data;
+  VLOG(3) << "ParallelExecutor begin to run member_->executor_->RunAndMerge";
+  auto fetch_data =
+      member_->executor_->Run(fetch_tensors, /*return_merged=*/true);
+  return BOOST_GET(FetchList, fetch_data);
 }
 
 void ParallelExecutor::RunWithoutFetch(
@@ -1440,6 +1430,38 @@ std::vector<ir::Graph *> ParallelExecutor::CloneGraphToMultiDevices(
   return graphs;
 }
 
+void ParallelExecutor::PreludeToRun(
+    const std::vector<std::string> &fetch_tensors) {
+  platform::RecordEvent record_run(
+      "ParallelExecutor::Run", platform::TracerEventType::UserDefined, 1);
+  VLOG(3) << "enter ParallelExecutor Run";
+#ifdef PADDLE_WITH_CUDA
+  if (platform::IsCUDAGraphCapturing()) {
+    PADDLE_ENFORCE_EQ(fetch_tensors.empty(),
+                      true,
+                      platform::errors::InvalidArgument(
+                          "Cannot fetch data when using CUDA Graph."));
+    PADDLE_ENFORCE_EQ(
+        member_->build_strategy_.allow_cuda_graph_capture_,
+        true,
+        platform::errors::InvalidArgument(
+            "You must turn on build_strategy.allow_cuda_graph_capture = True "
+            "to enable CUDA Graph capturing."));
+    PADDLE_ENFORCE_EQ(
+        member_->places_[0],
+        platform::CUDAGraphCapturingPlace(),
+        platform::errors::InvalidArgument("The place to capture CUDAGraph is "
+                                          "not the same as the place to run."));
+  }
+#endif
+
+#ifdef WITH_GPERFTOOLS
+  if (gProfileStarted) {
+    ProfilerFlush();
+  }
+#endif
+}
+
 void ParallelExecutor::PrepareNCCLCommunicator(Scope *global_scope) {
   if (member_->build_strategy_.reduce_ ==
       BuildStrategy::ReduceStrategy::kNoReduce) {
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index 4cb9c0340b53cb48625a92bbb8b0617f06451d93..a3b812a71a2b7e9657517712d76b4887de1285f1 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -89,8 +89,8 @@ class ParallelExecutor {
   void FeedAndSplitTensorIntoLocalScopes(
       const std::unordered_map<std::string, LoDTensor> &tensors);
 
-  FetchResultType Run(const std::vector<std::string> &fetch_tensors,
-                      bool return_merged = true);
+  FetchUnmergedList Run(const std::vector<std::string> &fetch_tensors);
+  FetchList RunAndMerge(const std::vector<std::string> &fetch_tensors);
 
   void RunWithoutFetch(const std::vector<std::string> &skip_eager_vars);
 
@@ -126,6 +126,8 @@ class ParallelExecutor {
 
   std::vector<ir::Graph *> CloneGraphToMultiDevices(ir::Graph *graph);
 
+  void PreludeToRun(const std::vector<std::string> &fetch_tensors);
+
   void PrepareNCCLCommunicator(Scope *global_scope);
 
   std::vector<ir::Graph *> CompileGraphWithBuildStrategy(
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index c294c8eb4a7c9f09a2be04fbcae23c4d3483f819..18a3fb1aab86b8e8562881e41fbf3bac1ac9e9a9 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -3225,13 +3225,17 @@ All parameter, weight, gradient are variables in Paddle.
 #endif
 
   m.def("set_feed_variable",
-        static_cast<void (*)(Scope *, const LoDTensor &, const std::string &,
-                             size_t)>(
-            &framework::SetFeedVariable));
+        static_cast<void (*)(  // NOLINT
+            Scope *,
+            const LoDTensor &,
+            const std::string &,
+            size_t)>(&framework::SetFeedVariable));
   m.def("set_feed_variable",
-        static_cast<void (*)(Scope *, const Strings &, const std::string &,
-                             size_t)>(
-            &framework::SetFeedVariable));
+        static_cast<void (*)(  // NOLINT
+            Scope *,
+            const Strings &,
+            const std::string &,
+            size_t)>(&framework::SetFeedVariable));
   m.def("get_fetch_variable",
         [](const Scope &scope,
            const std::string &var_name,
@@ -4601,20 +4605,20 @@ All parameter, weight, gradient are variables in Paddle.
           [](ParallelExecutor &self,
              const std::vector<std::string> &fetch_tensors,
              bool return_merged) -> py::object {
-            paddle::framework::FetchResultType ret;
-            {
-              pybind11::gil_scoped_release release;
-              ret = self.Run(fetch_tensors, return_merged);
-            }
-
-            // TODO(Ruibiao): Refactor the run interface of PE to avoid use
-            // boost::get here
             if (return_merged) {
-              return py::cast(
-                  std::move(boost::get<paddle::framework::FetchList>(ret)));
+              paddle::framework::FetchList ret;
+              /*gil_scoped_release*/ {
+                pybind11::gil_scoped_release release;
+                ret = self.RunAndMerge(fetch_tensors);
+              }
+              return py::cast(std::move(ret));
             } else {
-              return py::cast(std::move(
-                  boost::get<paddle::framework::FetchUnmergedList>(ret)));
+              paddle::framework::FetchUnmergedList ret;
+              /*gil_scoped_release*/ {
+                pybind11::gil_scoped_release release;
+                ret = self.Run(fetch_tensors);
+              }
+              return py::cast(std::move(ret));
             }
           })
       .def("device_count", &ParallelExecutor::DeviceCount);
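
Reviewer note: the sketch below is not part of the patch; it only illustrates how C++ call sites change under this refactor. `pe` stands in for an already-constructed `ParallelExecutor` and the fetched name `"loss"` is hypothetical; `Run`, `RunAndMerge`, `FetchList`, and `FetchUnmergedList` are the interfaces introduced above.

```cpp
// Before this patch, one entry point returned a two-alternative variant
// that every caller had to unpack itself:
//   FetchResultType result = pe.Run({"loss"}, /*return_merged=*/true);
//   auto &merged = boost::get<paddle::framework::FetchList>(result);

// After this patch, each entry point returns a concrete container, so the
// variant stays an implementation detail of the SSA graph executors:
FetchList merged = pe.RunAndMerge({"loss"});    // one merged tensor per fetch name
FetchUnmergedList per_device = pe.Run({"loss"});  // per-place tensors, left unmerged
```

Keeping the `BOOST_GET` unpacking inside `ParallelExecutor` also lets the pybind layer cast each concrete container directly to a Python object, releasing the GIL only around the actual run.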