diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 99d896848675c96d806ff36258e81d346a8d9db3..78f566c0356895f8a1828994b02f2e1a5f71755c 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -16,11 +16,17 @@ #include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/macros.h" + namespace paddle { namespace framework { namespace details { -struct OpHandleBase { +class OpHandleBase { + private: + DISABLE_COPY_AND_ASSIGN(OpHandleBase); + + public: std::vector inputs_; std::vector outputs_; std::unordered_map events_; #endif + OpHandleBase() {} + std::string DebugString() const; virtual std::string Name() const = 0; diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 7cfd66837966836774113229da3900c333f1c962..41034e9f0598854102fdc439bd586aa6f837991a 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -67,7 +67,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( } // Step 2. Insert FetchOps - std::vector fetch_ops; + std::vector> fetch_ops; std::vector dummy_vars; FeedFetchList fetch_data(fetch_tensors.size()); @@ -84,9 +84,9 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( for (size_t i = 0; i < fetch_tensors.size(); ++i) { auto &var_name = fetch_tensors[i]; - auto &vars = fetched_vars[var_name]; - fetch_ops.emplace_back(&fetch_data, i, &local_scopes_); - details::FetchOpHandle *op = &fetch_ops.back(); + auto &vars = fetched_vars.at(var_name); + auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_); + fetch_ops.emplace_back(op); // FIXME: Use new device context for (auto &p : places_) { @@ -138,7 +138,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( for (auto &op : pending_ops) { VLOG(10) << op.first->DebugString(); } - // keep waiting the ready variables continue; } diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py index 2e61eca0688fd5966f116341bfdfef1ef2630344..a5eea30f87a0a800c694ff605084851e73e3b13c 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py @@ -231,6 +231,9 @@ class TestMNIST(TestParallelExecutorBase): class TestResnet(TestParallelExecutorBase): @classmethod def setUpClass(cls): + import os + if os.path.exists('./flowers.recordio'): + return with fluid.program_guard(fluid.Program(), fluid.Program()): reader = paddle.batch(flowers.train(), batch_size=4) feeder = fluid.DataFeeder(