Commit 72be7a61 authored by Yu Yang

Complete RecordIO reader op

Parent bcb80756
...@@ -2,4 +2,10 @@ cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_regist ...@@ -2,4 +2,10 @@ cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_regist
op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry) op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry)
op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry) op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry)
op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry) op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry)
set(READER_LIBRARY create_random_data_generator_op create_shuffle_reader_op create_batch_reader_op PARENT_SCOPE) op_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc DEPS reader_op_registry)
set(READER_LIBRARY
create_recordio_file_reader_op
create_random_data_generator_op
create_shuffle_reader_op
create_batch_reader_op
PARENT_SCOPE)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/recordio/scanner.h"
namespace paddle {
namespace operators {
namespace reader {
class RecordIOFileReader : public framework::FileReader {
public:
RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& shapes)
: FileReader(shapes),
scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}
bool HasNext() const override { return scanner_.HasNext(); }
void ReInit() override { scanner_.Reset(); }
private:
recordio::Scanner scanner_;
const platform::DeviceContext& dev_ctx_;
};
// Operator that constructs a RecordIOFileReader from its attributes and
// stores it in the output ReaderHolder variable "Out".
class CreateRecordIOReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
    const auto& ranks = Attr<std::vector<int>>("ranks");
    PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty(),
                   "Attributes 'shape_concat' and 'ranks' must both be "
                   "non-empty.");
    // Each entry of `ranks` is the rank of one output tensor, so the ranks
    // must sum to the total number of dimensions packed in `shape_concat`.
    PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
                      static_cast<int>(shape_concat.size()),
                      "The accumulate of all ranks should be equal to the "
                      "shape concat's length.");
    std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
    std::string filename = Attr<std::string>("filename");
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    out->Reset(new RecordIOFileReader(filename, shapes));
  }
};
// Proto maker for create_recordio_file_reader. The common reader interface
// (output "Out", shape_concat, ranks, ...) is declared by the
// FileReaderMakerBase constructor; only the RecordIO-specific attribute is
// added here.
class CreateRecordIOReaderOpMaker : public FileReaderMakerBase {
 public:
  CreateRecordIOReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : FileReaderMakerBase(op_proto, op_checker) {
    // Path of the RecordIO file to read from.
    AddAttr<std::string>("filename", "The filename of record io reader");
    AddComment(R"DOC(
CreateRecordIOReader Operator
Create a reader from a record io file
)DOC");
  }
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace reader = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader,
reader::CreateRecordIOReaderOp,
reader::CreateRecordIOReaderOpMaker);
...@@ -35,7 +35,7 @@ FileReaderMakerBase::FileReaderMakerBase( ...@@ -35,7 +35,7 @@ FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto, framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) { : OpProtoAndCheckerMaker(op_proto, op_checker) {
AddOutput("Out", "(ReaderHolder) The created random reader."); AddOutput("Out", "(ReaderHolder) The created random reader.").AsDuplicable();
AddAttr<std::vector<int>>("shape_concat", "The concat of all data's shapes."); AddAttr<std::vector<int>>("shape_concat", "The concat of all data's shapes.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"ranks", "ranks",
......
if(WITH_PYTHON) if(WITH_PYTHON)
cc_library(paddle_pybind SHARED cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc const_value.cc SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc
DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
${GLOB_OP_LIB}) ${GLOB_OP_LIB})
if(NOT APPLE AND NOT ANDROID) if(NOT APPLE AND NOT ANDROID)
......
...@@ -35,7 +35,9 @@ limitations under the License. */ ...@@ -35,7 +35,9 @@ limitations under the License. */
#include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/pybind/recordio.h"
#include "paddle/fluid/pybind/tensor_py.h" #include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/fluid/string/to_string.h" #include "paddle/fluid/string/to_string.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -474,6 +476,8 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -474,6 +476,8 @@ All parameter, weight, gradient are variables in Paddle.
m.def("enable_profiler", platform::EnableProfiler); m.def("enable_profiler", platform::EnableProfiler);
m.def("disable_profiler", platform::DisableProfiler); m.def("disable_profiler", platform::DisableProfiler);
m.def("reset_profiler", platform::ResetProfiler); m.def("reset_profiler", platform::ResetProfiler);
BindRecordIOWriter(m);
return m.ptr(); return m.ptr();
} }
} // namespace pybind } // namespace pybind
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/pybind/recordio.h"
#include <fstream>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/writer.h"
namespace paddle {
namespace pybind {
// Buffers LoDTensors appended from Python and writes them to a RecordIO
// file in batches. Exposed to Python through BindRecordIOWriter.
class RecordIOWriter {
 public:
  RecordIOWriter(const std::string& filename, recordio::Compressor compressor,
                 size_t max_num_record)
      : stream_(filename), writer_(&stream_, compressor, max_num_record) {}

  // Queues one tensor; nothing is written until CompleteAppendTensor().
  void AppendTensor(const framework::LoDTensor& tensor) {
    tensors_.push_back(tensor);
  }

  // Serializes all queued tensors as one batch via WriteToRecordIO (on the
  // CPU device context) and clears the queue.
  void CompleteAppendTensor() {
    auto& ctx =
        *platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
    framework::WriteToRecordIO(writer_, tensors_, ctx);
    tensors_.clear();
  }

  // Flushes the writer and closes the file. Enforces that every appended
  // tensor was committed with CompleteAppendTensor() first.
  void Close() {
    PADDLE_ENFORCE(tensors_.empty());
    writer_.Flush();
    stream_.close();
  }

 private:
  std::vector<framework::LoDTensor> tensors_;
  // NOTE: stream_ must stay declared before writer_ — writer_ holds a
  // pointer to stream_ and members are initialized in declaration order.
  std::ofstream stream_;
  recordio::Writer writer_;
};
// Registers the RecordIOWriter class (and its nested Compressor enum) on
// the given pybind11 module.
void BindRecordIOWriter(py::module& m) {
  py::class_<RecordIOWriter> writer_cls(m, "RecordIOWriter", "");
  // The enum is scoped inside the RecordIOWriter Python class.
  py::enum_<recordio::Compressor>(writer_cls, "Compressor", "")
      .value("Snappy", recordio::Compressor::kSnappy)
      .value("NoCompress", recordio::Compressor::kNoCompress);
  writer_cls.def("__init__",
                 [](RecordIOWriter& self, const std::string& filename,
                    recordio::Compressor compressor, size_t max_num_record) {
                   // Placement-new into pybind11's pre-allocated storage.
                   new (&self)
                       RecordIOWriter(filename, compressor, max_num_record);
                 });
  writer_cls.def("append_tensor", &RecordIOWriter::AppendTensor);
  writer_cls.def("complete_append_tensor",
                 &RecordIOWriter::CompleteAppendTensor);
  writer_cls.def("close", &RecordIOWriter::Close);
}
} // namespace pybind
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
extern void BindRecordIOWriter(py::module& m);
} // namespace pybind
} // namespace paddle
...@@ -25,32 +25,36 @@ namespace recordio { ...@@ -25,32 +25,36 @@ namespace recordio {
constexpr size_t kMaxBufSize = 1024; constexpr size_t kMaxBufSize = 1024;
template <typename Callback> template <typename Callback>
static void ReadStreamByBuf(std::istream& in, int limit, Callback callback) { static void ReadStreamByBuf(std::istream& in, size_t limit, Callback callback) {
char buf[kMaxBufSize]; char buf[kMaxBufSize];
std::streamsize actual_size; std::streamsize actual_size;
size_t counter = 0; size_t counter = 0;
do { size_t actual_max;
auto actual_max = while (!in.eof() || (limit != 0 && counter >= limit)) {
limit > 0 ? std::min(limit - counter, kMaxBufSize) : kMaxBufSize; actual_max =
actual_size = in.readsome(buf, actual_max); limit != 0 ? std::min(limit - counter, kMaxBufSize) : kMaxBufSize;
in.read(buf, actual_max);
actual_size = in.gcount();
if (actual_size == 0) { if (actual_size == 0) {
break; break;
} }
callback(buf, actual_size); callback(buf, actual_size);
if (limit > 0) { if (limit != 0) {
counter += actual_size; counter += actual_size;
} }
} while (actual_size == kMaxBufSize); }
in.clear(); // unset eof state
} }
static void PipeStream(std::istream& in, std::ostream& os) { static void PipeStream(std::istream& in, std::ostream& os) {
ReadStreamByBuf( ReadStreamByBuf(
in, -1, [&os](const char* buf, size_t len) { os.write(buf, len); }); in, 0, [&os](const char* buf, size_t len) { os.write(buf, len); });
} }
static uint32_t Crc32Stream(std::istream& in, int limit = -1) { static uint32_t Crc32Stream(std::istream& in, size_t limit = 0) {
auto crc = crc32(0, nullptr, 0); uint32_t crc = static_cast<uint32_t>(crc32(0, nullptr, 0));
ReadStreamByBuf(in, limit, [&crc](const char* buf, size_t len) { ReadStreamByBuf(in, limit, [&crc](const char* buf, size_t len) {
crc = crc32(crc, reinterpret_cast<const Bytef*>(buf), len); crc = static_cast<uint32_t>(crc32(
crc, reinterpret_cast<const Bytef*>(buf), static_cast<uInt>(len)));
}); });
return crc; return crc;
} }
...@@ -85,14 +89,12 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const { ...@@ -85,14 +89,12 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const {
compressed_stream.reset(); compressed_stream.reset();
} }
auto end_pos = sout.tellg(); uint32_t len = static_cast<uint32_t>(sout.str().size());
sout.seekg(0, std::ios::beg);
uint32_t len = static_cast<uint32_t>(end_pos - sout.tellg());
uint32_t crc = Crc32Stream(sout); uint32_t crc = Crc32Stream(sout);
sout.seekg(0, std::ios::beg);
Header hdr(static_cast<uint32_t>(records_.size()), crc, ct, len); Header hdr(static_cast<uint32_t>(records_.size()), crc, ct, len);
hdr.Write(os); hdr.Write(os);
sout.seekg(0, std::ios::beg);
sout.clear();
PipeStream(sout, os); PipeStream(sout, os);
return true; return true;
} }
...@@ -104,12 +106,10 @@ bool Chunk::Parse(std::istream& sin) { ...@@ -104,12 +106,10 @@ bool Chunk::Parse(std::istream& sin) {
return ok; return ok;
} }
auto beg_pos = sin.tellg(); auto beg_pos = sin.tellg();
auto crc = Crc32Stream(sin, hdr.CompressSize()); uint32_t crc = Crc32Stream(sin, hdr.CompressSize());
PADDLE_ENFORCE_EQ(hdr.Checksum(), crc); PADDLE_ENFORCE_EQ(hdr.Checksum(), crc);
Clear(); Clear();
sin.seekg(beg_pos, sin.beg);
sin.seekg(beg_pos, std::ios::beg);
std::unique_ptr<std::istream> compressed_stream; std::unique_ptr<std::istream> compressed_stream;
switch (hdr.CompressType()) { switch (hdr.CompressType()) {
case Compressor::kNoCompress: case Compressor::kNoCompress:
......
...@@ -52,8 +52,8 @@ void Header::Write(std::ostream& os) const { ...@@ -52,8 +52,8 @@ void Header::Write(std::ostream& os) const {
} }
std::ostream& operator<<(std::ostream& os, Header h) { std::ostream& operator<<(std::ostream& os, Header h) {
os << h.NumRecords() << h.Checksum() os << "Header: " << h.NumRecords() << ", " << h.Checksum() << ", "
<< static_cast<uint32_t>(h.CompressType()) << h.CompressSize(); << static_cast<uint32_t>(h.CompressType()) << ", " << h.CompressSize();
return os; return os;
} }
......
...@@ -39,6 +39,7 @@ import clip ...@@ -39,6 +39,7 @@ import clip
from memory_optimization_transpiler import memory_optimize from memory_optimization_transpiler import memory_optimize
import profiler import profiler
import unique_name import unique_name
import recordio_writer
Tensor = LoDTensor Tensor = LoDTensor
...@@ -64,6 +65,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [ ...@@ -64,6 +65,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [
'memory_optimize', 'memory_optimize',
'profiler', 'profiler',
'unique_name', 'unique_name',
'recordio_writer',
] ]
......
...@@ -13,11 +13,15 @@ ...@@ -13,11 +13,15 @@
# limitations under the License. # limitations under the License.
from .. import core from .. import core
from ..layer_helper import LayerHelper from ..framework import convert_np_dtype_to_dtype_, default_main_program, default_startup_program
from ..unique_name import generate as unique_name
from control_flow import BlockGuard from control_flow import BlockGuard
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
__all__ = ['data', 'BlockGuardServ', 'ListenAndServ', 'Send'] __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'read_file'
]
def data(name, def data(name,
...@@ -224,3 +228,55 @@ def Recv(endpoints, get_vars): ...@@ -224,3 +228,55 @@ def Recv(endpoints, get_vars):
outputs={"Out": get_vars}, outputs={"Out": get_vars},
attrs={"endpoints": endpoints, attrs={"endpoints": endpoints,
"epmap": epmap}) "epmap": epmap})
def _copy_reader_var_(block, var):
    """Clone the READER variable `var` into `block` and return the clone.

    The clone keeps the original name, shapes and dtypes, and is marked
    persistable so it is shared across executor runs.
    """
    copied = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
    copied.desc.set_shapes(var.desc.shapes())
    copied.desc.set_dtypes(var.desc.dtypes())
    copied.persistable = True
    return copied
def open_recordio_file(filename, shapes, lod_levels, dtypes):
    """Create a reader variable that yields batches from a RecordIO file.

    The create_recordio_file_reader op runs in the startup program so the
    reader is constructed exactly once; the resulting READER variable is
    then copied into the main program for use with `read_file`.
    """
    dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
    # Flatten all shapes into one list; `ranks` records how to split it back.
    shape_concat = [dim for shape in shapes for dim in shape]
    ranks = [len(shape) for shape in shapes]

    reader_name = unique_name('open_recordio_file')
    startup_block = default_startup_program().current_block()
    reader_var = startup_block.create_var(name=reader_name)
    startup_block.append_op(
        type='create_recordio_file_reader',
        outputs={'Out': [reader_var]},
        attrs={
            'shape_concat': shape_concat,
            'lod_levels': lod_levels,
            'filename': filename,
            'ranks': ranks
        })
    reader_var.desc.set_dtypes(dtypes)
    reader_var.persistable = True
    return _copy_reader_var_(default_main_program().current_block(),
                             reader_var)
def read_file(file_obj):
    """Append a 'read' op that pulls the next batch from reader `file_obj`.

    Creates one temporary output variable per tensor the reader declares
    (one per entry of ``file_obj.desc.shapes()``). Returns the single
    variable directly when the reader yields exactly one tensor, otherwise
    the list of variables.
    """
    helper = LayerHelper('read_file')
    out = [
        helper.create_tmp_variable(
            stop_gradient=True, dtype='float32')
        for _ in range(len(file_obj.desc.shapes()))
    ]
    helper.append_op(
        type='read', inputs={'Reader': [file_obj]}, outputs={'Out': out})
    # Unwrap single-output readers for caller convenience.
    return out[0] if len(out) == 1 else out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import core
class RecordIOWriter(object):
    """Context manager around ``core.RecordIOWriter``.

    Tensors are buffered with `append_tensor` and committed as one record
    batch with `complete_append_tensor`. The underlying file is opened on
    `__enter__` and closed when the `with` block exits without error.
    """

    def __init__(self,
                 filename,
                 compressor=core.RecordIOWriter.Compressor.Snappy,
                 max_num_records=1000):
        self.filename = filename
        self.compressor = compressor
        self.max_num_records = max_num_records
        self.writer = None

    def __enter__(self):
        self.writer = core.RecordIOWriter(self.filename, self.compressor,
                                          self.max_num_records)
        # Return self so `with RecordIOWriter(...) as writer:` binds the
        # wrapper instead of None (previously __enter__ returned nothing).
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            # Do not close on error: close() requires all appended tensors
            # to have been committed and would raise a second exception.
            # Returning False propagates the original exception.
            return False
        else:
            self.writer.close()

    def append_tensor(self, tensor):
        self.writer.append_tensor(tensor)

    def complete_append_tensor(self):
        self.writer.complete_append_tensor()
def convert_reader_to_recordio_file(
        filename,
        reader_creator,
        feeder,
        compressor=core.RecordIOWriter.Compressor.Snappy,
        max_num_records=1000,
        feed_order=None):
    """Run every batch of `reader_creator` through `feeder` and serialize
    the fed tensors into the RecordIO file `filename`.

    When `feed_order` is given it fixes the order in which the fed tensors
    are appended; otherwise the feeder result's own key order is used.
    """
    writer = RecordIOWriter(filename, compressor, max_num_records)
    with writer:
        for batch in reader_creator():
            fed = feeder.feed(batch)
            # Iterating the dict yields its keys, matching the default order.
            keys = fed if feed_order is None else feed_order
            for key in keys:
                writer.append_tensor(fed[key])
            writer.complete_append_tensor()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.v2.dataset.mnist as mnist
import paddle.v2 as paddle
class TestRecordIO(unittest.TestCase):
    """End-to-end check: dump MNIST batches into a RecordIO file, then run
    one training step on data read back via open_recordio_file/read_file."""

    def setUp(self):
        # Write ./mnist.recordio from one pass over the batched MNIST reader.
        with fluid.program_guard(fluid.Program()):
            reader = paddle.batch(mnist.train(), batch_size=32)
            feeder = fluid.DataFeeder(
                feed_list=[
                    fluid.layers.data(
                        name='image', shape=[784]), fluid.layers.data(
                            name='label', shape=[1], dtype='int64')
                ],
                place=fluid.CPUPlace())
            fluid.recordio_writer.convert_reader_to_recordio_file(
                './mnist.recordio',
                reader,
                feeder,
                feed_order=['image', 'label'])

    def testMain(self):
        # Build a small classifier fed directly from the RecordIO reader
        # (no feed dict is passed to exe.run below).
        data_file = fluid.layers.open_recordio_file(
            './mnist.recordio',
            shapes=[[-1, 784], [-1, 1]],
            lod_levels=[0, 0],
            dtypes=['float32', 'int64'])
        img, label = fluid.layers.read_file(data_file)

        hidden = fluid.layers.fc(input=img, size=100, act='tanh')
        prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
        loss = fluid.layers.cross_entropy(input=prediction, label=label)
        avg_loss = fluid.layers.mean(loss)
        fluid.optimizer.SGD(learning_rate=1e-3).minimize(avg_loss)

        exe = fluid.Executor(fluid.CPUPlace())
        exe.run(fluid.default_startup_program())
        avg_loss_np, = exe.run(fetch_list=[avg_loss])
        # print() works under both Python 2 and 3 (was a py2-only statement).
        print(avg_loss_np)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册