Commit d84ddcf1, authored by Yu Yang

Stash

Parent 193c0a7e
@@ -45,7 +45,7 @@ struct ExecutorPrepareContext {
 
 Executor::Executor(const platform::Place& place) : place_(place) {}
 
-static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
+void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
   if (var_type == proto::VarType::LOD_TENSOR) {
     var->GetMutable<LoDTensor>();
   } else if (var_type == proto::VarType::SELECTED_ROWS) {
@@ -284,12 +284,12 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
       if (var->Persistable()) {
         auto* ptr = scope->Var(var->Name());
-        CreateTensor(ptr, var->GetType());
+        InitializeVariable(ptr, var->GetType());
         VLOG(3) << "Create Variable " << var->Name()
                 << " global, which pointer is " << ptr;
       } else {
         auto* ptr = local_scope->Var(var->Name());
-        CreateTensor(ptr, var->GetType());
+        InitializeVariable(ptr, var->GetType());
         VLOG(3) << "Create Variable " << var->Name()
                 << " locally, which pointer is " << ptr;
       }
@@ -297,7 +297,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   } else {
     for (auto& var : block.AllVars()) {
       auto* ptr = local_scope->Var(var->Name());
-      CreateTensor(ptr, var->GetType());
+      InitializeVariable(ptr, var->GetType());
       VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
              << ptr;
    }
......
@@ -59,5 +59,7 @@ class Executor {
   const platform::Place place_;
 };
 
+extern void InitializeVariable(Variable* var, proto::VarType::Type var_type);
+
 }  // namespace framework
 }  // namespace paddle
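
The extern declaration above promotes the renamed file-local helper to a public symbol of the framework namespace. A minimal usage sketch, assuming an existing Scope named scope and the types declared in this commit; the variable name is hypothetical:

    // Hedged sketch: InitializeVariable allocates the payload that matches the
    // declared type, so a later GetMutable<T>() on that type is safe.
    Variable* var = scope.Var("w@example");  // hypothetical variable name
    InitializeVariable(var, proto::VarType::LOD_TENSOR);
    auto* t = var->GetMutable<LoDTensor>();  // payload now exists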
@@ -84,14 +84,14 @@ struct ComputationOpHandle : public OpHandle {
   Scope *scope_;
   platform::Place place_;
 
-  explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
-                               platform::Place place)
+  explicit ComputationOpHandle(const OpDesc &op_desc, platform::Place place)
       : op_(framework::OpRegistry::CreateOp(op_desc)),
-        scope_(scope),
+        scope_(nullptr),
         place_(place) {}
 
   void Run() override {
     // Wait other op if necessary
+    LOG(INFO) << DebugString();
     auto *cur_ctx = dev_ctx_[place_];
     for (auto *in : inputs_) {
       if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
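
The hunk is truncated before the body of this cross-context branch. A hedged sketch of what it presumably does, where Wait is an assumed method name not shown anywhere in this diff:

    // If an input was produced under a different DeviceContext, block until it
    // is ready before running this op. `Wait` is hypothetical here.
    for (auto *in : inputs_) {
      if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
        in->generated_op_->Wait(cur_ctx);  // assumed cross-context barrier
      }
    }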
@@ -240,8 +240,7 @@ void ParallelExecutor::ConstructDependencyGraph(
     }
 
     for (auto &pair : member_->local_scopes_) {
-      member_->ops_.emplace_back(
-          new ComputationOpHandle(*op, pair.second, pair.first));
+      member_->ops_.emplace_back(new ComputationOpHandle(*op, pair.first));
       auto *op_handle = member_->ops_.back().get();
       op_handle->dev_ctx_[pair.first] = const_cast<platform::DeviceContext *>(
           platform::DeviceContextPool::Instance().Get(pair.first));
......
@@ -25,7 +25,9 @@ class RecordIOFileReader : public framework::FileReader {
       : FileReader(shapes),
         scanner_(filename),
         dev_ctx_(*platform::DeviceContextPool::Instance().Get(
-            platform::CPUPlace())) {}
+            platform::CPUPlace())) {
+    LOG(INFO) << "Creating file reader" << filename;
+  }
 
   void ReadNext(std::vector<framework::LoDTensor>* out) override {
     *out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
......
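
For reference, a hedged sketch of driving this reader directly; the constructor argument order is inferred from the initializer list above, and the shape values are illustrative only:

    // Hypothetical standalone use (setup inferred from the snippet above,
    // not from the full source file).
    std::vector<framework::DDim> shapes = {framework::make_ddim({-1, 784}),
                                           framework::make_ddim({-1, 1})};
    RecordIOFileReader reader("./mnist.recordio", shapes);
    std::vector<framework::LoDTensor> batch;
    reader.ReadNext(&batch);  // batch[0]: images, batch[1]: labels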
@@ -14,16 +14,33 @@
 
 import unittest
 import paddle.fluid as fluid
+import paddle.v2 as paddle
+import paddle.v2.dataset.mnist as mnist
 
 
 class ParallelExecutor(unittest.TestCase):
+    def setUp(self):
+        # Convert mnist to recordio file
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            reader = paddle.batch(mnist.train(), batch_size=32)
+            feeder = fluid.DataFeeder(
+                feed_list=[  # order is image and label
+                    fluid.layers.data(
+                        name='image', shape=[784]),
+                    fluid.layers.data(
+                        name='label', shape=[1], dtype='int64'),
+                ],
+                place=fluid.CPUPlace())
+            fluid.recordio_writer.convert_reader_to_recordio_file(
+                './mnist.recordio', reader, feeder)
+
     def test_main(self):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
             reader = fluid.layers.open_recordio_file(
-                filename='tmp',
+                filename='./mnist.recordio',
                 shapes=[[-1, 784], [-1, 1]],
                 lod_levels=[0, 0],
                 dtypes=['float32', 'int64'])