提交 f9556dca 编写于 作者: Y Yancey1989

use open_files reader to read multiple files

上级 a6a7b6f1
......@@ -89,14 +89,14 @@ The above codes would generate multiple RecordIO files on your host like:
```bash
.
\_mnist.recordio-00000
|-mnist.recordio-00001
|-mnist.recordio-00002
|-mnist.recordio-00003
|-mnist.recordio-00004
\_mnist-00000.recordio
|-mnist-00001.recordio
|-mnist-00002.recordio
|-mnist-00003.recordio
|-mnist-00004.recordio
```
1. read these RecordIO files with `fluid.layers.io.open_recordio_file`
1. open multiple RecordIO files by `fluid.layers.io.open_files`
For a distributed training job, the distributed operator system will schedule trainer process on multiple nodes,
each trainer process reads parts of the whole training data, we usually take the following approach to make the training
......@@ -113,10 +113,12 @@ def gen_train_list(file_pattern, trainers, trainer_id):
trainers = int(os.getenv("TRAINERS"))
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID"))
data_file = fluid.layers.io.open_recordio_file(
filename=gen_train_list("./mnist.recordio*", trainers, trainer_id),
shapes=[(-1, 784),(-1, 1)],
lod_levels=[0, 0],
dtypes=["float32", "int32"])
data_file = fluid.layers.io.batch(data_file, batch_size=4)
data_file = fluid.layers.io.open_files(
filenames=gen_train_list("./mnist-[0-9]*.recordio", 2, 0),
thread_num=1,
shapes=[(-1, 784),(-1, 1)],
lod_levels=[0, 0],
dtypes=["float32", "int32"])
img, label = fluid.layers.io.read_file(data_files)
...
```
......@@ -65,22 +65,20 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
auto filenames = Attr<std::vector<std::string>>("filenames");
std::string filename = Attr<std::string>("filename");
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
for (auto& fn : filenames) {
out->Reset(
new RecordIOFileReader<true>(fn, RestoreShapes(shape_concat, ranks)));
}
out->Reset(new RecordIOFileReader<true>(
filename, RestoreShapes(shape_concat, ranks)));
}
};
class CreateRecordIOReaderOpMaker : public FileReaderMakerBase {
protected:
void Apply() override {
AddAttr<std::vector<std::string>>("filenames",
"The filenames of record io reader");
AddAttr<std::string>("filename", "The filename of record io reader");
AddComment(R"DOC(
CreateRecordIOReader Operator
......
......@@ -21,7 +21,7 @@ from ..layer_helper import LayerHelper
from ..executor import global_scope
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_files',
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
'random_data_generator', 'Preprocessor'
]
......@@ -291,12 +291,12 @@ def _copy_reader_create_op_(block, op):
return new_op
def open_recordio_files(filenames,
shapes,
lod_levels,
dtypes,
pass_num=1,
for_parallel=True):
def open_recordio_file(filename,
shapes,
lod_levels,
dtypes,
pass_num=1,
for_parallel=True):
"""
Open a RecordIO file
......@@ -304,7 +304,7 @@ def open_recordio_files(filenames,
Via the Reader Variable, we can get data from the given RecordIO file.
Args:
filename(str) or list(str): The RecordIO file's name.
filename(str): The RecordIO file's name.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
......@@ -336,8 +336,6 @@ def open_recordio_files(filenames,
ranks.append(len(shape))
var_name = unique_name('open_recordio_file')
if isinstance(filenames, str):
filenames = [filenames]
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
......@@ -347,7 +345,7 @@ def open_recordio_files(filenames,
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'filenames': filenames,
'filename': filename,
'ranks': ranks
})
......
......@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import core
import contextlib
__all__ = ['convert_reader_to_recordio_file']
__all__ = [
'convert_reader_to_recordio_file', 'convert_reader_to_recordio_files'
]
@contextlib.contextmanager
......@@ -48,7 +51,7 @@ def convert_reader_to_recordio_file(
def convert_reader_to_recordio_files(
filename_suffix,
filename,
batch_per_file,
reader_creator,
feeder,
......@@ -57,13 +60,16 @@ def convert_reader_to_recordio_files(
feed_order=None):
if feed_order is None:
feed_order = feeder.feed_names
f_name, f_ext = os.path.splitext(filename)
assert (f_ext == ".recordio")
lines = []
f_idx = 0
counter = 0
for idx, batch in enumerate(reader_creator()):
lines.append(batch)
if idx >= batch_per_file and idx % batch_per_file == 0:
filename = "%s-%05d" % (filename_suffix, f_idx)
filename = "%s-%05d%s" % (f_name, f_idx, f_ext)
with create_recordio_writer(filename, compressor,
max_num_records) as writer:
for l in lines:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册