Commit 7cd6de37 authored by Yancey1989

fix cpu test=develop

Parent bd0d44af
@@ -36,7 +36,6 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
   for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
     auto &dev_ctx = op->DeviceContext();
     auto &p = dev_ctx.begin()->first;
-#ifdef PADDLE_WITH_CUDA
     int dev_id = boost::get<platform::CUDAPlace>(p).device;
     auto &dev_ops = graphs[dev_id]->Get<GraphOps>(kGraphOps);
     auto &dev_dummys = graphs[dev_id]->Get<GraphDepVars>(kGraphDepVars);
@@ -59,9 +58,6 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
         graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release());
       }
     }
-#else
-    PADDLE_THROW("Parallel Graph Execution only support CUDAPlace.");
-#endif
   }
   for (size_t dev_id = 0; dev_id < places.size(); ++dev_id) {
...
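The hunks above remove the PADDLE_WITH_CUDA guard from inside SeparateMultiDevicesGraph: the pass now always routes each op into the per-device graph keyed by the device id of its CUDAPlace, and the CPU-only check is handled earlier in ParallelExecutor (see the second file below). A minimal standalone sketch of that routing step, with made-up types (Op, SeparateByDevice) that are not the real Paddle API:

// Illustrative sketch only -- Op and SeparateByDevice are stand-ins, not
// Paddle's GraphOps/places. It mirrors the loop above: each op carries the
// device id of its place, and the pass moves it into that device's graph.
#include <cstddef>
#include <vector>

struct Op {
  int device_id;  // analogous to boost::get<platform::CUDAPlace>(p).device
};

std::vector<std::vector<Op>> SeparateByDevice(const std::vector<Op> &ops,
                                              std::size_t num_devices) {
  std::vector<std::vector<Op>> per_device(num_devices);
  for (const auto &op : ops) {
    per_device[op.device_id].push_back(op);  // route op to its device's bucket
  }
  return per_device;
}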
@@ -304,6 +304,7 @@ ParallelExecutor::ParallelExecutor(
   }
   if (build_strategy.enable_parallel_graph_) {
+#ifdef PADDLE_WITH_CUDA
     auto parallel_graph =
         details::SeparateMultiDevicesGraph(member_->places_, std::move(graph));
     auto seq_allreduce_pass =
@@ -319,6 +320,10 @@ ParallelExecutor::ParallelExecutor(
     member_->executor_.reset(new details::ParallelSSAGraphExecutor(
         exec_strategy, member_->local_scopes_, member_->places_,
         std::move(parallel_graph)));
+#else
+    PADDLE_THROW(
+        "Paddle should be compiled with CUDA for ParallelGraph Execution.");
+#endif
   } else {
     if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
       member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
...
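Taken together, the commit moves the CUDA requirement for ParallelGraph execution from the graph-separation pass up into the ParallelExecutor constructor, so a CPU-only build fails fast with a clear message instead of reaching CUDAPlace-specific code. A hedged sketch of that compile-time guard pattern, using a plain exception in place of PADDLE_THROW and an invented function name:

// Sketch of the guard pattern, not Paddle's real entry point. When the
// translation unit is compiled without PADDLE_WITH_CUDA, the parallel-graph
// branch is replaced by an immediate, descriptive failure.
#include <stdexcept>

void MaybeRunParallelGraph(bool enable_parallel_graph) {
  if (!enable_parallel_graph) return;
#ifdef PADDLE_WITH_CUDA
  // CUDA build: separate the graph per device and build the
  // ParallelSSAGraphExecutor (elided here).
#else
  // CPU-only build: surface the limitation up front, mirroring PADDLE_THROW.
  throw std::runtime_error(
      "Paddle should be compiled with CUDA for ParallelGraph Execution.");
#endif
}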