diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc index 137e0dd7708dcc77c3a927979cfb357249f1fdc9..1bd27263f7dad5f733c553c202444ba7cacd2510 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc @@ -106,7 +106,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( VLOG(1) << "set num_threads: " << strategy_.num_threads_ << " to run the operators of the graph on each device."; for (size_t i = 0; i < places.size(); ++i) { - executors_.emplace_back(new details::ThreadedSSAGraphExecutor( + executors_.emplace_back(new details::FastThreadedSSAGraphExecutor( strategy_, local_scopes_, {places_[i]}, graphs_.at(i).get())); } } diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h index 1e421f2a3a51363fe368859f7a34593c8c894077..faf071b05306a49c0049421bc72e4981c0bfc84c 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h @@ -14,12 +14,12 @@ #pragma once +#include #include #include - #include "ThreadPool.h" +#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" -#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/ir/graph.h" namespace paddle { @@ -48,7 +48,8 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor { std::vector places_; std::vector> graphs_; - std::vector> executors_; + std::vector> + executors_; ExceptionHolder exception_holder_; };