Unverified commit 8c2eba71, authored by yuyang18

Refine demo

Parent: 9a570fb9
...
@@ -168,7 +168,13 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
   for (size_t i = 0; i < fetch_tensors.size(); ++i) {
     auto &var_name = fetch_tensors[i];
-    auto &vars = fetched_vars.at(var_name);
+    auto fetched_var_it = fetched_vars.find(var_name);
+    PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
+                   "Cannot find fetched variable. (Perhaps the main_program "
+                   "is not set to ParallelExecutor)");
+    auto &vars = fetched_var_it->second;
     auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_);
     fetch_ops->emplace_back(op);
...
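
The change above replaces a bare fetched_vars.at(var_name), which on a missing fetch target would abort with an uninformative std::out_of_range, with a find plus PADDLE_ENFORCE that names the likely cause. A minimal Python sketch of the situation the message points at, using names from the demo diff below; this snippet is illustrative, not part of the commit:

    # Built without main_program, the executor may not know about the
    # program that defines the fetch target, so the fetched_vars lookup
    # in InsertFetchOps can fail:
    trainer = fluid.ParallelExecutor(use_cuda=True, loss_name=loss.name)

    # Passing the defining program makes the fetch target resolvable:
    trainer = fluid.ParallelExecutor(
        use_cuda=True, loss_name=loss.name, main_program=train_prog)
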
...
@@ -36,7 +36,7 @@ def network(is_train):
     prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
     loss = fluid.layers.cross_entropy(input=prediction, label=label)
-    return fluid.layers.mean(loss), queue
+    return fluid.layers.mean(loss), queue, reader
 
 
 def pipe_reader_to_queue(reader_creator, queue):
@@ -70,27 +70,46 @@ def main():
     with fluid.program_guard(train_prog, startup_prog):
         with fluid.unique_name.guard():
-            loss, train_queue = network(True)
+            loss, train_queue, train_reader = network(True)
             adam = fluid.optimizer.Adam(learning_rate=0.01)
             adam.minimize(loss)
 
     test_prog = fluid.Program()
     with fluid.program_guard(test_prog, fluid.Program()):
         with fluid.unique_name.guard():
-            test_loss, test_queue = network(False)
+            test_loss, test_queue, test_reader = network(False)
 
     fluid.Executor(fluid.CUDAPlace(0)).run(startup_prog)
 
-    trainer = fluid.ParallelExecutor(use_cuda=True, loss_name=loss.name)
-    tester = fluid.ParallelExecutor(use_cuda=True, share_vars_from=trainer)
+    trainer = fluid.ParallelExecutor(
+        use_cuda=True, loss_name=loss.name, main_program=train_prog)
+    tester = fluid.ParallelExecutor(
+        use_cuda=True, share_vars_from=trainer, main_program=test_prog)
 
     for epoch_id in xrange(10):
-        pipe_reader_to_queue(paddle.batch(mnist.train(), 32), train_queue)
-        pipe_reader_to_queue(paddle.batch(mnist.test(), 32), test_queue)
+        train_data_thread = pipe_reader_to_queue(
+            paddle.batch(mnist.train(), 32), train_queue)
         try:
-            print 'train_loss', numpy.array(trainer.run(fetch_list=[loss.name]))
+            while True:
+                print 'train_loss', numpy.array(
+                    trainer.run(fetch_list=[loss.name]))
         except fluid.core.EOFException:
             print 'End of epoch', epoch_id
+            train_reader.reset()
+        train_data_thread.join()
+
+        test_data_thread = pipe_reader_to_queue(
+            paddle.batch(mnist.test(), 32), test_queue)
+        try:
+            while True:
+                print numpy.array(tester.run(fetch_list=[test_loss.name]))
+        except fluid.core.EOFException:
+            print 'End of testing'
+            test_reader.reset()
+        test_data_thread.join()
+        break
 
 
 if __name__ == '__main__':
...
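
The refined demo loop follows the same pattern in each phase: start a feeder thread that pipes a reader into the queue, run the executor until the queue raises EOFException, reset the reader inside the handler, then join the feeder before moving on. A minimal sketch of that pattern; run_until_eof is an illustrative helper, not part of the commit, and executor, fetch_list, reader, and feeder_thread stand in for the demo's trainer/tester, loss names, readers, and pipe_reader_to_queue threads:

    import numpy
    import paddle.fluid as fluid

    def run_until_eof(executor, fetch_list, reader, feeder_thread):
        try:
            while True:  # drain the queue batch by batch
                print numpy.array(executor.run(fetch_list=fetch_list))
        except fluid.core.EOFException:  # queue exhausted for this epoch
            reader.reset()  # re-arm the reader for the next epoch
        feeder_thread.join()  # make sure the producer has finished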