未验证 提交 e8eb81ca 编写于 作者: Y yuyang18

Just train two batch

上级 bcb1516d
...@@ -520,7 +520,7 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None): ...@@ -520,7 +520,7 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None):
startup_var = startup_blk.create_var(name=reader_name) startup_var = startup_blk.create_var(name=reader_name)
startup_blk.append_op( startup_blk.append_op(
type='create_py_reader', type='create_py_reader',
inputs={'blocking_queue': queue_name}, inputs={'blocking_queue': [queue_name]},
outputs={'Out': [startup_var]}, outputs={'Out': [startup_var]},
attrs={ attrs={
'shape_concat': shape_concat, 'shape_concat': shape_concat,
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.dataset.mnist as mnist import paddle.dataset.mnist as mnist
import paddle import paddle
import paddle.v2
import threading import threading
import numpy import numpy
...@@ -91,7 +92,8 @@ def main(): ...@@ -91,7 +92,8 @@ def main():
for epoch_id in xrange(10): for epoch_id in xrange(10):
train_data_thread = pipe_reader_to_queue( train_data_thread = pipe_reader_to_queue(
paddle.batch(mnist.train(), 32), train_queue) paddle.batch(paddle.v2.reader.firstn(mnist.train(), 32), 64),
train_queue)
try: try:
while True: while True:
print 'train_loss', numpy.array( print 'train_loss', numpy.array(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册