From a6d468a26500c36633a34833fc10e6a9fc23c416 Mon Sep 17 00:00:00 2001 From: chengduo Date: Tue, 16 Jul 2019 13:17:23 +0800 Subject: [PATCH] fix PE fetch bug (#18644) test=develop --- .../framework/details/fast_threaded_ssa_graph_executor.cc | 5 ++++- .../fluid/framework/details/threaded_ssa_graph_executor.cc | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 11052273d2..7daab6dac1 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" @@ -124,7 +125,9 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps( std::unordered_map> *op_deps, std::vector *fetch_ops, std::vector *ready_fetch_ops) { - for (auto &fetch_var_name : fetch_tensors) { + std::unordered_set fetch_tensor_set(fetch_tensors.begin(), + fetch_tensors.end()); + for (auto &fetch_var_name : fetch_tensor_set) { for (auto &var_map : graph_->Get(kGraphVars)) { auto it = var_map.find(fetch_var_name); if (it != var_map.end()) { diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index ed9d7d991f..db28e1fe20 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" - #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/platform/profiler.h" @@ -157,7 +156,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( FeedFetchList *fetch_data) { std::unordered_map> fetched_vars; std::unordered_set local_ready_vars; - for (auto &fetch_var_name : fetch_tensors) { + std::unordered_set fetch_tensor_set(fetch_tensors.begin(), + fetch_tensors.end()); + for (auto &fetch_var_name : fetch_tensor_set) { for (auto &var_map : graph_->Get(details::kGraphVars)) { auto it = var_map.find(fetch_var_name); if (it != var_map.end()) { -- GitLab