提交 f9556dca 编写于 作者: Y Yancey1989

use open_files reader to read multiple files

上级 a6a7b6f1
...@@ -89,14 +89,14 @@ The above codes would generate multiple RecordIO files on your host like: ...@@ -89,14 +89,14 @@ The above codes would generate multiple RecordIO files on your host like:
```bash ```bash
. .
\_mnist.recordio-00000 \_mnist-00000.recordio
|-mnist.recordio-00001 |-mnist-00001.recordio
|-mnist.recordio-00002 |-mnist-00002.recordio
|-mnist.recordio-00003 |-mnist-00003.recordio
|-mnist.recordio-00004 |-mnist-00004.recordio
``` ```
1. read these RecordIO files with `fluid.layers.io.open_recordio_file` 1. open multiple RecordIO files by `fluid.layers.io.open_files`
For a distributed training job, the distributed operator system will schedule trainer process on multiple nodes, For a distributed training job, the distributed operator system will schedule trainer process on multiple nodes,
each trainer process reads parts of the whole training data, we usually take the following approach to make the training each trainer process reads parts of the whole training data, we usually take the following approach to make the training
...@@ -113,10 +113,12 @@ def gen_train_list(file_pattern, trainers, trainer_id): ...@@ -113,10 +113,12 @@ def gen_train_list(file_pattern, trainers, trainer_id):
trainers = int(os.getenv("TRAINERS")) trainers = int(os.getenv("TRAINERS"))
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID")) trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID"))
data_file = fluid.layers.io.open_recordio_file( data_file = fluid.layers.io.open_files(
filename=gen_train_list("./mnist.recordio*", trainers, trainer_id), filenames=gen_train_list("./mnist-[0-9]*.recordio", 2, 0),
shapes=[(-1, 784),(-1, 1)], thread_num=1,
lod_levels=[0, 0], shapes=[(-1, 784),(-1, 1)],
dtypes=["float32", "int32"]) lod_levels=[0, 0],
data_file = fluid.layers.io.batch(data_file, batch_size=4) dtypes=["float32", "int32"])
img, label = fluid.layers.io.read_file(data_files)
...
``` ```
...@@ -65,22 +65,20 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { ...@@ -65,22 +65,20 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
static_cast<int>(shape_concat.size()), static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the " "The accumulate of all ranks should be equal to the "
"shape concat's length."); "shape concat's length.");
auto filenames = Attr<std::vector<std::string>>("filenames"); std::string filename = Attr<std::string>("filename");
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
for (auto& fn : filenames) {
out->Reset( out->Reset(new RecordIOFileReader<true>(
new RecordIOFileReader<true>(fn, RestoreShapes(shape_concat, ranks))); filename, RestoreShapes(shape_concat, ranks)));
}
} }
}; };
class CreateRecordIOReaderOpMaker : public FileReaderMakerBase { class CreateRecordIOReaderOpMaker : public FileReaderMakerBase {
protected: protected:
void Apply() override { void Apply() override {
AddAttr<std::vector<std::string>>("filenames", AddAttr<std::string>("filename", "The filename of record io reader");
"The filenames of record io reader");
AddComment(R"DOC( AddComment(R"DOC(
CreateRecordIOReader Operator CreateRecordIOReader Operator
......
...@@ -21,7 +21,7 @@ from ..layer_helper import LayerHelper ...@@ -21,7 +21,7 @@ from ..layer_helper import LayerHelper
from ..executor import global_scope from ..executor import global_scope
__all__ = [ __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_files', 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer', 'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
'random_data_generator', 'Preprocessor' 'random_data_generator', 'Preprocessor'
] ]
...@@ -291,12 +291,12 @@ def _copy_reader_create_op_(block, op): ...@@ -291,12 +291,12 @@ def _copy_reader_create_op_(block, op):
return new_op return new_op
def open_recordio_files(filenames, def open_recordio_file(filename,
shapes, shapes,
lod_levels, lod_levels,
dtypes, dtypes,
pass_num=1, pass_num=1,
for_parallel=True): for_parallel=True):
""" """
Open a RecordIO file Open a RecordIO file
...@@ -304,7 +304,7 @@ def open_recordio_files(filenames, ...@@ -304,7 +304,7 @@ def open_recordio_files(filenames,
Via the Reader Variable, we can get data from the given RecordIO file. Via the Reader Variable, we can get data from the given RecordIO file.
Args: Args:
filename(str) or list(str): The RecordIO file's name. filename(str): The RecordIO file's name.
shapes(list): List of tuples which declaring data shapes. shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level. lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type. dtypes(list): List of strs which declaring data type.
...@@ -336,8 +336,6 @@ def open_recordio_files(filenames, ...@@ -336,8 +336,6 @@ def open_recordio_files(filenames,
ranks.append(len(shape)) ranks.append(len(shape))
var_name = unique_name('open_recordio_file') var_name = unique_name('open_recordio_file')
if isinstance(filenames, str):
filenames = [filenames]
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name) startup_var = startup_blk.create_var(name=var_name)
...@@ -347,7 +345,7 @@ def open_recordio_files(filenames, ...@@ -347,7 +345,7 @@ def open_recordio_files(filenames,
attrs={ attrs={
'shape_concat': shape_concat, 'shape_concat': shape_concat,
'lod_levels': lod_levels, 'lod_levels': lod_levels,
'filenames': filenames, 'filename': filename,
'ranks': ranks 'ranks': ranks
}) })
......
...@@ -12,9 +12,12 @@ ...@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import core import core
import contextlib import contextlib
__all__ = ['convert_reader_to_recordio_file'] __all__ = [
'convert_reader_to_recordio_file', 'convert_reader_to_recordio_files'
]
@contextlib.contextmanager @contextlib.contextmanager
...@@ -48,7 +51,7 @@ def convert_reader_to_recordio_file( ...@@ -48,7 +51,7 @@ def convert_reader_to_recordio_file(
def convert_reader_to_recordio_files( def convert_reader_to_recordio_files(
filename_suffix, filename,
batch_per_file, batch_per_file,
reader_creator, reader_creator,
feeder, feeder,
...@@ -57,13 +60,16 @@ def convert_reader_to_recordio_files( ...@@ -57,13 +60,16 @@ def convert_reader_to_recordio_files(
feed_order=None): feed_order=None):
if feed_order is None: if feed_order is None:
feed_order = feeder.feed_names feed_order = feeder.feed_names
f_name, f_ext = os.path.splitext(filename)
assert (f_ext == ".recordio")
lines = [] lines = []
f_idx = 0 f_idx = 0
counter = 0 counter = 0
for idx, batch in enumerate(reader_creator()): for idx, batch in enumerate(reader_creator()):
lines.append(batch) lines.append(batch)
if idx >= batch_per_file and idx % batch_per_file == 0: if idx >= batch_per_file and idx % batch_per_file == 0:
filename = "%s-%05d" % (filename_suffix, f_idx) filename = "%s-%05d%s" % (f_name, f_idx, f_ext)
with create_recordio_writer(filename, compressor, with create_recordio_writer(filename, compressor,
max_num_records) as writer: max_num_records) as writer:
for l in lines: for l in lines:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册