未验证 提交 7c194acc 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #9929 from JiayiFeng/make_for_parallel_default_true

make for_parallel default True
...@@ -21,7 +21,7 @@ from ..executor import global_scope ...@@ -21,7 +21,7 @@ from ..executor import global_scope
__all__ = [ __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'shuffle', 'double_buffer' 'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer'
] ]
...@@ -290,7 +290,7 @@ def open_recordio_file(filename, ...@@ -290,7 +290,7 @@ def open_recordio_file(filename,
lod_levels, lod_levels,
dtypes, dtypes,
pass_num=1, pass_num=1,
for_parallel=False): for_parallel=True):
""" """
Open a RecordIO file Open a RecordIO file
...@@ -364,7 +364,7 @@ def open_files(filenames, ...@@ -364,7 +364,7 @@ def open_files(filenames,
thread_num, thread_num,
buffer_size=None, buffer_size=None,
pass_num=1, pass_num=1,
for_parallel=False): for_parallel=True):
""" """
Open files Open files
...@@ -476,6 +476,11 @@ def shuffle(reader, buffer_size): ...@@ -476,6 +476,11 @@ def shuffle(reader, buffer_size):
'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)}) 'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)})
def batch(reader, batch_size):
return __create_unshared_decorated_reader__(
'create_batch_reader', reader, {'batch_size': int(batch_size)})
def double_buffer(reader, place=None): def double_buffer(reader, place=None):
attrs = dict() attrs = dict()
if place is not None: if place is not None:
......
...@@ -69,7 +69,6 @@ class TestMultipleReader(unittest.TestCase): ...@@ -69,7 +69,6 @@ class TestMultipleReader(unittest.TestCase):
break break
batch_count += 1 batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size) self.assertLessEqual(img_val.shape[0], self.batch_size)
data_files.reset()
self.assertEqual(batch_count, self.num_batch * 3) self.assertEqual(batch_count, self.num_batch * 3)
def test_main(self): def test_main(self):
......
...@@ -43,9 +43,8 @@ class TestMultipleReader(unittest.TestCase): ...@@ -43,9 +43,8 @@ class TestMultipleReader(unittest.TestCase):
filename='./mnist.recordio', filename='./mnist.recordio',
shapes=[(-1, 784), (-1, 1)], shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64']) dtypes=['float32', 'int64'],
data_file = fluid.layers.io.multi_pass( pass_num=self.pass_num)
reader=data_file, pass_num=self.pass_num)
img, label = fluid.layers.read_file(data_file) img, label = fluid.layers.read_file(data_file)
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
...@@ -65,5 +64,4 @@ class TestMultipleReader(unittest.TestCase): ...@@ -65,5 +64,4 @@ class TestMultipleReader(unittest.TestCase):
break break
batch_count += 1 batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size) self.assertLessEqual(img_val.shape[0], self.batch_size)
data_file.reset()
self.assertEqual(batch_count, self.num_batch * self.pass_num) self.assertEqual(batch_count, self.num_batch * self.pass_num)
...@@ -74,12 +74,12 @@ class TestRecordIO(unittest.TestCase): ...@@ -74,12 +74,12 @@ class TestRecordIO(unittest.TestCase):
avg_loss_np.append(tmp) avg_loss_np.append(tmp)
batch_id += 1 batch_id += 1
data_file.reset()
self.assertEqual(batch_id, self.num_batches) self.assertEqual(batch_id, self.num_batches)
self.assertLess(avg_loss_np[-1], avg_loss_np[0]) self.assertLess(avg_loss_np[-1], avg_loss_np[0])
def test_shuffle_reader(self): def test_shuffle_reader(self):
self.test_main(decorator_callback=lambda reader: fluid.layers.io.shuffle(reader, buffer_size=200)) self.test_main(decorator_callback=lambda reader: fluid.layers.io.shuffle(
reader, buffer_size=200))
def test_double_buffer_reader(self): def test_double_buffer_reader(self):
self.test_main(decorator_callback=lambda reader: fluid.layers.io.double_buffer(reader, self.test_main(decorator_callback=lambda reader: fluid.layers.io.double_buffer(reader,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册