未验证 提交 f720e231 编写于 作者: R Ruibiao Chen 提交者: GitHub

Remove boost::variant for FetchResultType (#43932)

* Remove boost::variant for FetchResultType

* Fix pybind errors
上级 6467ca0d
......@@ -174,7 +174,7 @@ FetchResultType AsyncSSAGraphExecutor::Run(
HandleException();
FetchList ret;
auto &val = boost::get<FetchList>(fetch_data);
auto &val = BOOST_GET(FetchList, fetch_data);
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
if (data_is_lod_tensor(val.at(fetch_idx))) {
std::vector<const LoDTensor *> lodtensor_ptrs;
......
......@@ -228,7 +228,7 @@ void FetchAsyncOpHandle::RunImpl() {
}
if (return_merged_) {
auto &val = boost::get<FetchList>(*data_);
auto &val = BOOST_GET(FetchList, *data_);
if (src_vars[0]->IsType<LoDTensor>()) {
// to lodtensor type
std::vector<const LoDTensor *> src_lodtensors;
......@@ -263,7 +263,7 @@ void FetchAsyncOpHandle::RunImpl() {
val.at(offset_) = std::move(dst_lodtensor_array);
}
} else {
auto &val = boost::get<FetchUnmergedList>(*data_);
auto &val = BOOST_GET(FetchUnmergedList, *data_);
auto &dst_tensors = val.at(offset_);
dst_tensors.reserve(src_vars.size());
......
......@@ -84,7 +84,7 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
for (auto &t : tensors_) {
tensors_ptr.emplace_back(&BOOST_GET_CONST(LoDTensor, t));
}
auto &val = boost::get<FetchList>(*data_);
auto &val = BOOST_GET(FetchList, *data_);
LoDTensor var;
MergeLoDTensor(&var, tensors_ptr, platform::CPUPlace());
val.at(offset_) = std::move(var);
......@@ -106,11 +106,11 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
tmp_array.emplace_back();
MergeLoDTensor(&(tmp_array.back()), tensors_ptr, platform::CPUPlace());
}
auto &val = boost::get<FetchList>(*data_);
auto &val = BOOST_GET(FetchList, *data_);
val.at(offset_) = std::move(tmp_array);
}
} else {
auto &val = boost::get<FetchUnmergedList>(*data_);
auto &val = BOOST_GET(FetchUnmergedList, *data_);
val.at(offset_) = std::move(tensors_);
}
}
......
......@@ -278,7 +278,8 @@ FetchResultType ParallelSSAGraphExecutor::Run(
if (!is_valid[scope_idx]) {
continue;
}
const auto &fetch_list = boost::get<FetchList>(fetch_data[scope_idx]);
const auto &fetch_list =
BOOST_GET_CONST(FetchList, fetch_data[scope_idx]);
if (data_is_lod_tensor(fetch_list[fetch_idx])) {
lodtensor_ptrs.push_back(
&(BOOST_GET_CONST(LoDTensor, fetch_list[fetch_idx])));
......@@ -317,7 +318,7 @@ FetchResultType ParallelSSAGraphExecutor::Run(
continue;
}
const auto &fetch_list =
boost::get<FetchUnmergedList>(fetch_data[scope_idx]);
BOOST_GET_CONST(FetchUnmergedList, fetch_data[scope_idx]);
PADDLE_ENFORCE_EQ(
fetch_list[fetch_idx].size(),
1,
......
......@@ -30,7 +30,7 @@ using FetchType = paddle::variant<LoDTensor, LoDTensorArray, framework::Vocab>;
using FetchList = std::vector<FetchType>;
using FetchUnmergedList = std::vector<std::vector<FetchType>>;
using FetchResultType = boost::variant<FetchList, FetchUnmergedList>;
using FetchResultType = paddle::variant<FetchList, FetchUnmergedList>;
inline bool data_is_lod_tensor(const FetchType &data) {
if (data.type() == typeid(LoDTensor)) {
......
......@@ -972,37 +972,26 @@ void ParallelExecutor::BCastParamsToDevices(
}
}
FetchResultType ParallelExecutor::Run(
const std::vector<std::string> &fetch_tensors, bool return_merged) {
platform::RecordEvent record_run(
"ParallelExecutor::Run", platform::TracerEventType::UserDefined, 1);
VLOG(3) << "enter ParallelExecutor Run";
#ifdef PADDLE_WITH_CUDA
if (platform::IsCUDAGraphCapturing()) {
PADDLE_ENFORCE_EQ(fetch_tensors.empty(),
true,
platform::errors::InvalidArgument(
"Cannot fetch data when using CUDA Graph."));
PADDLE_ENFORCE_EQ(
member_->build_strategy_.allow_cuda_graph_capture_,
true,
platform::errors::InvalidArgument(
"You must turn on build_strategy.allow_cuda_graph_capture = True "
"to enable CUDA Graph capturing."));
PADDLE_ENFORCE_EQ(
member_->places_[0],
platform::CUDAGraphCapturingPlace(),
platform::errors::InvalidArgument("The place to capture CUDAGraph is "
"not the same as the place to run."));
}
#endif
FetchUnmergedList ParallelExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
PreludeToRun(fetch_tensors);
platform::RecordBlock b(0);
#ifdef WITH_GPERFTOOLS
if (gProfileStarted) {
ProfilerFlush();
}
#endif
ResetHasFeedGuard reset_has_feed_guard(member_);
ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_),
fetch_tensors,
member_->HasGarbageCollectors());
VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run";
auto fetch_data =
member_->executor_->Run(fetch_tensors, /*return_merged=*/false);
return BOOST_GET(FetchUnmergedList, fetch_data);
}
FetchList ParallelExecutor::RunAndMerge(
const std::vector<std::string> &fetch_tensors) {
PreludeToRun(fetch_tensors);
platform::RecordBlock b(0);
ResetHasFeedGuard reset_has_feed_guard(member_);
......@@ -1011,9 +1000,10 @@ FetchResultType ParallelExecutor::Run(
fetch_tensors,
member_->HasGarbageCollectors());
VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run";
auto fetch_data = member_->executor_->Run(fetch_tensors, return_merged);
return fetch_data;
VLOG(3) << "ParallelExecutor begin to run member_->executor_->RunAndMerge";
auto fetch_data =
member_->executor_->Run(fetch_tensors, /*return_merged=*/true);
return BOOST_GET(FetchList, fetch_data);
}
void ParallelExecutor::RunWithoutFetch(
......@@ -1440,6 +1430,38 @@ std::vector<ir::Graph *> ParallelExecutor::CloneGraphToMultiDevices(
return graphs;
}
void ParallelExecutor::PreludeToRun(
const std::vector<std::string> &fetch_tensors) {
platform::RecordEvent record_run(
"ParallelExecutor::Run", platform::TracerEventType::UserDefined, 1);
VLOG(3) << "enter ParallelExecutor Run";
#ifdef PADDLE_WITH_CUDA
if (platform::IsCUDAGraphCapturing()) {
PADDLE_ENFORCE_EQ(fetch_tensors.empty(),
true,
platform::errors::InvalidArgument(
"Cannot fetch data when using CUDA Graph."));
PADDLE_ENFORCE_EQ(
member_->build_strategy_.allow_cuda_graph_capture_,
true,
platform::errors::InvalidArgument(
"You must turn on build_strategy.allow_cuda_graph_capture = True "
"to enable CUDA Graph capturing."));
PADDLE_ENFORCE_EQ(
member_->places_[0],
platform::CUDAGraphCapturingPlace(),
platform::errors::InvalidArgument("The place to capture CUDAGraph is "
"not the same as the place to run."));
}
#endif
#ifdef WITH_GPERFTOOLS
if (gProfileStarted) {
ProfilerFlush();
}
#endif
}
void ParallelExecutor::PrepareNCCLCommunicator(Scope *global_scope) {
if (member_->build_strategy_.reduce_ ==
BuildStrategy::ReduceStrategy::kNoReduce) {
......
......@@ -89,8 +89,8 @@ class ParallelExecutor {
void FeedAndSplitTensorIntoLocalScopes(
const std::unordered_map<std::string, LoDTensor> &tensors);
FetchResultType Run(const std::vector<std::string> &fetch_tensors,
bool return_merged = true);
FetchUnmergedList Run(const std::vector<std::string> &fetch_tensors);
FetchList RunAndMerge(const std::vector<std::string> &fetch_tensors);
void RunWithoutFetch(const std::vector<std::string> &skip_eager_vars);
......@@ -126,6 +126,8 @@ class ParallelExecutor {
std::vector<ir::Graph *> CloneGraphToMultiDevices(ir::Graph *graph);
void PreludeToRun(const std::vector<std::string> &fetch_tensors);
void PrepareNCCLCommunicator(Scope *global_scope);
std::vector<ir::Graph *> CompileGraphWithBuildStrategy(
......
......@@ -3225,13 +3225,17 @@ All parameter, weight, gradient are variables in Paddle.
#endif
m.def("set_feed_variable",
static_cast<void (*)(
Scope *, const LoDTensor &, const std::string &, size_t)>(
&framework::SetFeedVariable));
static_cast<void (*)( // NOLINT
Scope *,
const LoDTensor &,
const std::string &,
size_t)>(&framework::SetFeedVariable));
m.def("set_feed_variable",
static_cast<void (*)(
Scope *, const Strings &, const std::string &, size_t)>(
&framework::SetFeedVariable));
static_cast<void (*)( // NOLINT
Scope *,
const Strings &,
const std::string &,
size_t)>(&framework::SetFeedVariable));
m.def("get_fetch_variable",
[](const Scope &scope,
const std::string &var_name,
......@@ -4601,20 +4605,20 @@ All parameter, weight, gradient are variables in Paddle.
[](ParallelExecutor &self,
const std::vector<std::string> &fetch_tensors,
bool return_merged) -> py::object {
paddle::framework::FetchResultType ret;
{
if (return_merged) {
paddle::framework::FetchList ret;
/*gil_scoped_release*/ {
pybind11::gil_scoped_release release;
ret = self.Run(fetch_tensors, return_merged);
ret = self.RunAndMerge(fetch_tensors);
}
// TODO(Ruibiao): Refactor the run interface of PE to avoid use
// boost::get here
if (return_merged) {
return py::cast(
std::move(boost::get<paddle::framework::FetchList>(ret)));
return py::cast(std::move(ret));
} else {
return py::cast(std::move(
boost::get<paddle::framework::FetchUnmergedList>(ret)));
paddle::framework::FetchUnmergedList ret;
/*gil_scoped_release*/ {
pybind11::gil_scoped_release release;
ret = self.Run(fetch_tensors);
}
return py::cast(std::move(ret));
}
})
.def("device_count", &ParallelExecutor::DeviceCount);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册