提交 0dff1e41 编写于 作者: J JiayiFeng

Fix bugs

上级 1b5a6351
......@@ -21,7 +21,7 @@ from ..executor import global_scope
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'shuffle', 'double_buffer'
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer'
]
......@@ -469,6 +469,11 @@ def shuffle(reader, buffer_size):
'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)})
def batch(reader, batch_size):
return __create_unshared_decorated_reader__(
'create_batch_reader', reader, {'batch_size': int(batch_size)})
def double_buffer(reader, place=None):
attrs = dict()
if place is not None:
......
......@@ -43,9 +43,8 @@ class TestMultipleReader(unittest.TestCase):
filename='./mnist.recordio',
shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
data_file = fluid.layers.io.multi_pass(
reader=data_file, pass_num=self.pass_num)
dtypes=['float32', 'int64'],
pass_num=self.pass_num)
img, label = fluid.layers.read_file(data_file)
if fluid.core.is_compiled_with_cuda():
......@@ -65,5 +64,4 @@ class TestMultipleReader(unittest.TestCase):
break
batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size)
data_file.reset()
self.assertEqual(batch_count, self.num_batch * self.pass_num)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册