提交 d84ddcf1 编写于 作者: Y Yu Yang

Stash

上级 193c0a7e
......@@ -45,7 +45,7 @@ struct ExecutorPrepareContext {
Executor::Executor(const platform::Place& place) : place_(place) {}
static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarType::SELECTED_ROWS) {
......@@ -284,12 +284,12 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
......@@ -297,7 +297,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} else {
for (auto& var : block.AllVars()) {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
......
......@@ -59,5 +59,7 @@ class Executor {
const platform::Place place_;
};
extern void InitializeVariable(Variable* var, proto::VarType::Type var_type);
} // namespace framework
} // namespace paddle
......@@ -84,14 +84,14 @@ struct ComputationOpHandle : public OpHandle {
Scope *scope_;
platform::Place place_;
explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
platform::Place place)
explicit ComputationOpHandle(const OpDesc &op_desc, platform::Place place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
scope_(scope),
scope_(nullptr),
place_(place) {}
void Run() override {
// Wait other op if necessary
LOG(INFO) << DebugString();
auto *cur_ctx = dev_ctx_[place_];
for (auto *in : inputs_) {
if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
......@@ -240,8 +240,7 @@ void ParallelExecutor::ConstructDependencyGraph(
}
for (auto &pair : member_->local_scopes_) {
member_->ops_.emplace_back(
new ComputationOpHandle(*op, pair.second, pair.first));
member_->ops_.emplace_back(new ComputationOpHandle(*op, pair.first));
auto *op_handle = member_->ops_.back().get();
op_handle->dev_ctx_[pair.first] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(pair.first));
......
......@@ -25,7 +25,9 @@ class RecordIOFileReader : public framework::FileReader {
: FileReader(shapes),
scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {}
platform::CPUPlace())) {
LOG(INFO) << "Creating file reader" << filename;
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
......
......@@ -14,16 +14,33 @@
import unittest
import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
class ParallelExecutor(unittest.TestCase):
def setUp(self):
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=32)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist.recordio', reader, feeder)
def test_main(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
reader = fluid.layers.open_recordio_file(
filename='tmp',
filename='./mnist.recordio',
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册