未验证 提交 f720e231 编写于 作者: R Ruibiao Chen 提交者: GitHub

Remove boost::variant for FetchResultType (#43932)

* Remove boost::variant for FetchResultType

* Fix pybind errors
上级 6467ca0d
...@@ -174,7 +174,7 @@ FetchResultType AsyncSSAGraphExecutor::Run( ...@@ -174,7 +174,7 @@ FetchResultType AsyncSSAGraphExecutor::Run(
HandleException(); HandleException();
FetchList ret; FetchList ret;
auto &val = boost::get<FetchList>(fetch_data); auto &val = BOOST_GET(FetchList, fetch_data);
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) { for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
if (data_is_lod_tensor(val.at(fetch_idx))) { if (data_is_lod_tensor(val.at(fetch_idx))) {
std::vector<const LoDTensor *> lodtensor_ptrs; std::vector<const LoDTensor *> lodtensor_ptrs;
......
...@@ -228,7 +228,7 @@ void FetchAsyncOpHandle::RunImpl() { ...@@ -228,7 +228,7 @@ void FetchAsyncOpHandle::RunImpl() {
} }
if (return_merged_) { if (return_merged_) {
auto &val = boost::get<FetchList>(*data_); auto &val = BOOST_GET(FetchList, *data_);
if (src_vars[0]->IsType<LoDTensor>()) { if (src_vars[0]->IsType<LoDTensor>()) {
// to lodtensor type // to lodtensor type
std::vector<const LoDTensor *> src_lodtensors; std::vector<const LoDTensor *> src_lodtensors;
...@@ -263,7 +263,7 @@ void FetchAsyncOpHandle::RunImpl() { ...@@ -263,7 +263,7 @@ void FetchAsyncOpHandle::RunImpl() {
val.at(offset_) = std::move(dst_lodtensor_array); val.at(offset_) = std::move(dst_lodtensor_array);
} }
} else { } else {
auto &val = boost::get<FetchUnmergedList>(*data_); auto &val = BOOST_GET(FetchUnmergedList, *data_);
auto &dst_tensors = val.at(offset_); auto &dst_tensors = val.at(offset_);
dst_tensors.reserve(src_vars.size()); dst_tensors.reserve(src_vars.size());
......
...@@ -84,7 +84,7 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const { ...@@ -84,7 +84,7 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
for (auto &t : tensors_) { for (auto &t : tensors_) {
tensors_ptr.emplace_back(&BOOST_GET_CONST(LoDTensor, t)); tensors_ptr.emplace_back(&BOOST_GET_CONST(LoDTensor, t));
} }
auto &val = boost::get<FetchList>(*data_); auto &val = BOOST_GET(FetchList, *data_);
LoDTensor var; LoDTensor var;
MergeLoDTensor(&var, tensors_ptr, platform::CPUPlace()); MergeLoDTensor(&var, tensors_ptr, platform::CPUPlace());
val.at(offset_) = std::move(var); val.at(offset_) = std::move(var);
...@@ -106,11 +106,11 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const { ...@@ -106,11 +106,11 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
tmp_array.emplace_back(); tmp_array.emplace_back();
MergeLoDTensor(&(tmp_array.back()), tensors_ptr, platform::CPUPlace()); MergeLoDTensor(&(tmp_array.back()), tensors_ptr, platform::CPUPlace());
} }
auto &val = boost::get<FetchList>(*data_); auto &val = BOOST_GET(FetchList, *data_);
val.at(offset_) = std::move(tmp_array); val.at(offset_) = std::move(tmp_array);
} }
} else { } else {
auto &val = boost::get<FetchUnmergedList>(*data_); auto &val = BOOST_GET(FetchUnmergedList, *data_);
val.at(offset_) = std::move(tensors_); val.at(offset_) = std::move(tensors_);
} }
} }
......
...@@ -278,7 +278,8 @@ FetchResultType ParallelSSAGraphExecutor::Run( ...@@ -278,7 +278,8 @@ FetchResultType ParallelSSAGraphExecutor::Run(
if (!is_valid[scope_idx]) { if (!is_valid[scope_idx]) {
continue; continue;
} }
const auto &fetch_list = boost::get<FetchList>(fetch_data[scope_idx]); const auto &fetch_list =
BOOST_GET_CONST(FetchList, fetch_data[scope_idx]);
if (data_is_lod_tensor(fetch_list[fetch_idx])) { if (data_is_lod_tensor(fetch_list[fetch_idx])) {
lodtensor_ptrs.push_back( lodtensor_ptrs.push_back(
&(BOOST_GET_CONST(LoDTensor, fetch_list[fetch_idx]))); &(BOOST_GET_CONST(LoDTensor, fetch_list[fetch_idx])));
...@@ -317,7 +318,7 @@ FetchResultType ParallelSSAGraphExecutor::Run( ...@@ -317,7 +318,7 @@ FetchResultType ParallelSSAGraphExecutor::Run(
continue; continue;
} }
const auto &fetch_list = const auto &fetch_list =
boost::get<FetchUnmergedList>(fetch_data[scope_idx]); BOOST_GET_CONST(FetchUnmergedList, fetch_data[scope_idx]);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
fetch_list[fetch_idx].size(), fetch_list[fetch_idx].size(),
1, 1,
......
...@@ -30,7 +30,7 @@ using FetchType = paddle::variant<LoDTensor, LoDTensorArray, framework::Vocab>; ...@@ -30,7 +30,7 @@ using FetchType = paddle::variant<LoDTensor, LoDTensorArray, framework::Vocab>;
using FetchList = std::vector<FetchType>; using FetchList = std::vector<FetchType>;
using FetchUnmergedList = std::vector<std::vector<FetchType>>; using FetchUnmergedList = std::vector<std::vector<FetchType>>;
using FetchResultType = boost::variant<FetchList, FetchUnmergedList>; using FetchResultType = paddle::variant<FetchList, FetchUnmergedList>;
inline bool data_is_lod_tensor(const FetchType &data) { inline bool data_is_lod_tensor(const FetchType &data) {
if (data.type() == typeid(LoDTensor)) { if (data.type() == typeid(LoDTensor)) {
......
...@@ -972,37 +972,26 @@ void ParallelExecutor::BCastParamsToDevices( ...@@ -972,37 +972,26 @@ void ParallelExecutor::BCastParamsToDevices(
} }
} }
FetchResultType ParallelExecutor::Run( FetchUnmergedList ParallelExecutor::Run(
const std::vector<std::string> &fetch_tensors, bool return_merged) { const std::vector<std::string> &fetch_tensors) {
platform::RecordEvent record_run( PreludeToRun(fetch_tensors);
"ParallelExecutor::Run", platform::TracerEventType::UserDefined, 1); platform::RecordBlock b(0);
VLOG(3) << "enter ParallelExecutor Run";
#ifdef PADDLE_WITH_CUDA
if (platform::IsCUDAGraphCapturing()) {
PADDLE_ENFORCE_EQ(fetch_tensors.empty(),
true,
platform::errors::InvalidArgument(
"Cannot fetch data when using CUDA Graph."));
PADDLE_ENFORCE_EQ(
member_->build_strategy_.allow_cuda_graph_capture_,
true,
platform::errors::InvalidArgument(
"You must turn on build_strategy.allow_cuda_graph_capture = True "
"to enable CUDA Graph capturing."));
PADDLE_ENFORCE_EQ(
member_->places_[0],
platform::CUDAGraphCapturingPlace(),
platform::errors::InvalidArgument("The place to capture CUDAGraph is "
"not the same as the place to run."));
}
#endif
#ifdef WITH_GPERFTOOLS ResetHasFeedGuard reset_has_feed_guard(member_);
if (gProfileStarted) {
ProfilerFlush(); ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_),
} fetch_tensors,
#endif member_->HasGarbageCollectors());
VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run";
auto fetch_data =
member_->executor_->Run(fetch_tensors, /*return_merged=*/false);
return BOOST_GET(FetchUnmergedList, fetch_data);
}
FetchList ParallelExecutor::RunAndMerge(
const std::vector<std::string> &fetch_tensors) {
PreludeToRun(fetch_tensors);
platform::RecordBlock b(0); platform::RecordBlock b(0);
ResetHasFeedGuard reset_has_feed_guard(member_); ResetHasFeedGuard reset_has_feed_guard(member_);
...@@ -1011,9 +1000,10 @@ FetchResultType ParallelExecutor::Run( ...@@ -1011,9 +1000,10 @@ FetchResultType ParallelExecutor::Run(
fetch_tensors, fetch_tensors,
member_->HasGarbageCollectors()); member_->HasGarbageCollectors());
VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run"; VLOG(3) << "ParallelExecutor begin to run member_->executor_->RunAndMerge";
auto fetch_data = member_->executor_->Run(fetch_tensors, return_merged); auto fetch_data =
return fetch_data; member_->executor_->Run(fetch_tensors, /*return_merged=*/true);
return BOOST_GET(FetchList, fetch_data);
} }
void ParallelExecutor::RunWithoutFetch( void ParallelExecutor::RunWithoutFetch(
...@@ -1440,6 +1430,38 @@ std::vector<ir::Graph *> ParallelExecutor::CloneGraphToMultiDevices( ...@@ -1440,6 +1430,38 @@ std::vector<ir::Graph *> ParallelExecutor::CloneGraphToMultiDevices(
return graphs; return graphs;
} }
// Shared pre-flight checks executed before both Run() and RunAndMerge().
// Emits a profiler event and validates that fetching is legal under the
// current CUDA Graph capture state. Does not mutate executor state.
// NOTE(review): extracted as a helper in this commit so Run/RunAndMerge
// do not duplicate these checks — confirm both callers invoke it first.
void ParallelExecutor::PreludeToRun(
const std::vector<std::string> &fetch_tensors) {
// Records a user-defined trace event covering the whole Run call.
platform::RecordEvent record_run(
"ParallelExecutor::Run", platform::TracerEventType::UserDefined, 1);
VLOG(3) << "enter ParallelExecutor Run";
#ifdef PADDLE_WITH_CUDA
if (platform::IsCUDAGraphCapturing()) {
// Fetching copies tensors to the host, which cannot be replayed as part
// of a captured CUDA Graph, so fetch_tensors must be empty here.
PADDLE_ENFORCE_EQ(fetch_tensors.empty(),
true,
platform::errors::InvalidArgument(
"Cannot fetch data when using CUDA Graph."));
// Capturing must have been explicitly opted into via the build strategy.
PADDLE_ENFORCE_EQ(
member_->build_strategy_.allow_cuda_graph_capture_,
true,
platform::errors::InvalidArgument(
"You must turn on build_strategy.allow_cuda_graph_capture = True "
"to enable CUDA Graph capturing."));
// A graph captured on one device cannot be replayed on another; the
// first execution place must match the capturing place.
PADDLE_ENFORCE_EQ(
member_->places_[0],
platform::CUDAGraphCapturingPlace(),
platform::errors::InvalidArgument("The place to capture CUDAGraph is "
"not the same as the place to run."));
}
#endif
#ifdef WITH_GPERFTOOLS
// Flush any pending gperftools CPU-profile samples so profiles stay
// current across long-running training loops.
if (gProfileStarted) {
ProfilerFlush();
}
#endif
}
void ParallelExecutor::PrepareNCCLCommunicator(Scope *global_scope) { void ParallelExecutor::PrepareNCCLCommunicator(Scope *global_scope) {
if (member_->build_strategy_.reduce_ == if (member_->build_strategy_.reduce_ ==
BuildStrategy::ReduceStrategy::kNoReduce) { BuildStrategy::ReduceStrategy::kNoReduce) {
......
...@@ -89,8 +89,8 @@ class ParallelExecutor { ...@@ -89,8 +89,8 @@ class ParallelExecutor {
void FeedAndSplitTensorIntoLocalScopes( void FeedAndSplitTensorIntoLocalScopes(
const std::unordered_map<std::string, LoDTensor> &tensors); const std::unordered_map<std::string, LoDTensor> &tensors);
FetchResultType Run(const std::vector<std::string> &fetch_tensors, FetchUnmergedList Run(const std::vector<std::string> &fetch_tensors);
bool return_merged = true); FetchList RunAndMerge(const std::vector<std::string> &fetch_tensors);
void RunWithoutFetch(const std::vector<std::string> &skip_eager_vars); void RunWithoutFetch(const std::vector<std::string> &skip_eager_vars);
...@@ -126,6 +126,8 @@ class ParallelExecutor { ...@@ -126,6 +126,8 @@ class ParallelExecutor {
std::vector<ir::Graph *> CloneGraphToMultiDevices(ir::Graph *graph); std::vector<ir::Graph *> CloneGraphToMultiDevices(ir::Graph *graph);
void PreludeToRun(const std::vector<std::string> &fetch_tensors);
void PrepareNCCLCommunicator(Scope *global_scope); void PrepareNCCLCommunicator(Scope *global_scope);
std::vector<ir::Graph *> CompileGraphWithBuildStrategy( std::vector<ir::Graph *> CompileGraphWithBuildStrategy(
......
...@@ -3225,13 +3225,17 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -3225,13 +3225,17 @@ All parameter, weight, gradient are variables in Paddle.
#endif #endif
m.def("set_feed_variable", m.def("set_feed_variable",
static_cast<void (*)( static_cast<void (*)( // NOLINT
Scope *, const LoDTensor &, const std::string &, size_t)>( Scope *,
&framework::SetFeedVariable)); const LoDTensor &,
const std::string &,
size_t)>(&framework::SetFeedVariable));
m.def("set_feed_variable", m.def("set_feed_variable",
static_cast<void (*)( static_cast<void (*)( // NOLINT
Scope *, const Strings &, const std::string &, size_t)>( Scope *,
&framework::SetFeedVariable)); const Strings &,
const std::string &,
size_t)>(&framework::SetFeedVariable));
m.def("get_fetch_variable", m.def("get_fetch_variable",
[](const Scope &scope, [](const Scope &scope,
const std::string &var_name, const std::string &var_name,
...@@ -4601,20 +4605,20 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -4601,20 +4605,20 @@ All parameter, weight, gradient are variables in Paddle.
[](ParallelExecutor &self, [](ParallelExecutor &self,
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
bool return_merged) -> py::object { bool return_merged) -> py::object {
paddle::framework::FetchResultType ret;
{
pybind11::gil_scoped_release release;
ret = self.Run(fetch_tensors, return_merged);
}
// TODO(Ruibiao): Refactor the run interface of PE to avoid use
// boost::get here
if (return_merged) { if (return_merged) {
return py::cast( paddle::framework::FetchList ret;
std::move(boost::get<paddle::framework::FetchList>(ret))); /*gil_scoped_release*/ {
pybind11::gil_scoped_release release;
ret = self.RunAndMerge(fetch_tensors);
}
return py::cast(std::move(ret));
} else { } else {
return py::cast(std::move( paddle::framework::FetchUnmergedList ret;
boost::get<paddle::framework::FetchUnmergedList>(ret))); /*gil_scoped_release*/ {
pybind11::gil_scoped_release release;
ret = self.Run(fetch_tensors);
}
return py::cast(std::move(ret));
} }
}) })
.def("device_count", &ParallelExecutor::DeviceCount); .def("device_count", &ParallelExecutor::DeviceCount);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册