From 9af870854e99c4eba22506b085cdb1b521f70f20 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 27 Mar 2018 14:30:58 +0800 Subject: [PATCH] Use heap variables --- paddle/fluid/framework/details/op_handle_base.h | 10 +++++++++- .../framework/details/threaded_ssa_graph_executor.cc | 9 ++++----- .../fluid/tests/unittests/test_parallel_executor.py | 3 +++ 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 99d89684867..78f566c0356 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -16,11 +16,17 @@ #include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/macros.h" + namespace paddle { namespace framework { namespace details { -struct OpHandleBase { +class OpHandleBase { + private: + DISABLE_COPY_AND_ASSIGN(OpHandleBase); + + public: std::vector inputs_; std::vector outputs_; std::unordered_map events_; #endif + OpHandleBase() {} + std::string DebugString() const; virtual std::string Name() const = 0; diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 7cfd6683796..41034e9f059 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -67,7 +67,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( } // Step 2. Insert FetchOps - std::vector fetch_ops; + std::vector> fetch_ops; std::vector dummy_vars; FeedFetchList fetch_data(fetch_tensors.size()); @@ -84,9 +84,9 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( for (size_t i = 0; i < fetch_tensors.size(); ++i) { auto &var_name = fetch_tensors[i]; - auto &vars = fetched_vars[var_name]; - fetch_ops.emplace_back(&fetch_data, i, &local_scopes_); - details::FetchOpHandle *op = &fetch_ops.back(); + auto &vars = fetched_vars.at(var_name); + auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_); + fetch_ops.emplace_back(op); // FIXME: Use new device context for (auto &p : places_) { @@ -138,7 +138,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( for (auto &op : pending_ops) { VLOG(10) << op.first->DebugString(); } - // keep waiting the ready variables continue; } diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py index 2e61eca0688..a5eea30f87a 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py @@ -231,6 +231,9 @@ class TestMNIST(TestParallelExecutorBase): class TestResnet(TestParallelExecutorBase): @classmethod def setUpClass(cls): + import os + if os.path.exists('./flowers.recordio'): + return with fluid.program_guard(fluid.Program(), fluid.Program()): reader = paddle.batch(flowers.train(), batch_size=4) feeder = fluid.DataFeeder( -- GitLab