Commit 9af87085 authored by Yu Yang

Use heap variables

Parent 22276329
@@ -16,11 +16,17 @@
 #include "paddle/fluid/framework/details/var_handle.h"
 #include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/macros.h"

 namespace paddle {
 namespace framework {
 namespace details {

-struct OpHandleBase {
+class OpHandleBase {
+ private:
+  DISABLE_COPY_AND_ASSIGN(OpHandleBase);
+
+ public:
   std::vector<VarHandleBase *> inputs_;
   std::vector<VarHandleBase *> outputs_;
   std::unordered_map<platform::Place, platform::DeviceContext *,
@@ -31,6 +37,8 @@ struct OpHandleBase {
   std::unordered_map<int, cudaEvent_t> events_;
 #endif

+  OpHandleBase() {}
+
   std::string DebugString() const;

   virtual std::string Name() const = 0;
...
@@ -67,7 +67,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   }

   // Step 2. Insert FetchOps
-  std::vector<FetchOpHandle> fetch_ops;
+  std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
   std::vector<DummyVarHandle> dummy_vars;
   FeedFetchList fetch_data(fetch_tensors.size());
@@ -84,9 +84,9 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   for (size_t i = 0; i < fetch_tensors.size(); ++i) {
     auto &var_name = fetch_tensors[i];
-    auto &vars = fetched_vars[var_name];
-    fetch_ops.emplace_back(&fetch_data, i, &local_scopes_);
-    details::FetchOpHandle *op = &fetch_ops.back();
+    auto &vars = fetched_vars.at(var_name);
+    auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_);
+    fetch_ops.emplace_back(op);

     // FIXME: Use new device context
     for (auto &p : places_) {
@@ -138,7 +138,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
       for (auto &op : pending_ops) {
         VLOG(10) << op.first->DebugString();
       }
-
       // keep waiting the ready variables
       continue;
     }
...
@@ -231,6 +231,9 @@ class TestMNIST(TestParallelExecutorBase):
 class TestResnet(TestParallelExecutorBase):
     @classmethod
     def setUpClass(cls):
+        import os
+        if os.path.exists('./flowers.recordio'):
+            return
         with fluid.program_guard(fluid.Program(), fluid.Program()):
             reader = paddle.batch(flowers.train(), batch_size=4)
             feeder = fluid.DataFeeder(
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册