diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 11052273d2849b4b8836c55466e205b8fd0789de..7daab6dac19768e1d35c84bfd78d319c8a62512b 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" @@ -124,7 +125,9 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps( std::unordered_map> *op_deps, std::vector *fetch_ops, std::vector *ready_fetch_ops) { - for (auto &fetch_var_name : fetch_tensors) { + std::unordered_set fetch_tensor_set(fetch_tensors.begin(), + fetch_tensors.end()); + for (auto &fetch_var_name : fetch_tensor_set) { for (auto &var_map : graph_->Get(kGraphVars)) { auto it = var_map.find(fetch_var_name); if (it != var_map.end()) { diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index ed9d7d991f830428f79a56a440cb9c9a5ad86509..db28e1fe202116f49e0266a7bc24ddfb351c8bb4 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" - #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/platform/profiler.h" @@ -157,7 +156,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( FeedFetchList *fetch_data) { std::unordered_map> fetched_vars; std::unordered_set local_ready_vars; - for (auto &fetch_var_name : fetch_tensors) { + std::unordered_set fetch_tensor_set(fetch_tensors.begin(), + fetch_tensors.end()); + for (auto &fetch_var_name : fetch_tensor_set) { for (auto &var_map : graph_->Get(details::kGraphVars)) { auto it = var_map.find(fetch_var_name); if (it != var_map.end()) {