diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 48713f2c2ac62a37b7b7a4602f7f6a325aecb0b8..15e5574ecfd406b87db8370948352b7e736937ea 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -21,7 +21,7 @@ endif() cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) nv_test(mixed_vector_test SRCS mixed_vector_test.cu DEPS place paddle_memory device_context init) -cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto) +cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor init) diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index e2f4e9cad1996578b7c51257785e1273d126f80f..8155cb55a468a09320b1196b49fc3e34cea261b1 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -19,6 +19,9 @@ limitations under the License. */ #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/recordio/scanner.h" +#include "paddle/fluid/recordio/writer.h" + #include #include #include @@ -291,6 +294,31 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, TensorFromStream(is, static_cast(tensor), dev_ctx); } +void WriteToRecordIO(recordio::Writer &writer, + const std::vector &tensor, + const platform::DeviceContext &dev_ctx) { + std::stringstream buffer; + size_t sz = tensor.size(); + buffer.write(reinterpret_cast(&sz), sizeof(uint32_t)); + for (auto &each : tensor) { + SerializeToStream(buffer, each, dev_ctx); + } + writer.Write(buffer.str()); +} + +std::vector ReadFromRecordIO( + recordio::Scanner &scanner, const platform::DeviceContext &dev_ctx) { + std::istringstream sin(scanner.Next()); + uint32_t sz; + sin.read(reinterpret_cast(&sz), sizeof(uint32_t)); + std::vector result; + result.resize(sz); + for (uint32_t i = 0; i < sz; ++i) { + DeserializeFromStream(sin, &result[i], dev_ctx); + } + return result; +} + std::vector LoDTensor::SplitLoDTensor( const std::vector places) const { check_memory_size(); diff --git a/paddle/fluid/framework/lod_tensor.h b/paddle/fluid/framework/lod_tensor.h index 94d5a6e9fd9b68d3d8230a8c258316efadda5a05..dee505fee0dccd8d60bb290a8bec4df243e504a2 100644 --- a/paddle/fluid/framework/lod_tensor.h +++ b/paddle/fluid/framework/lod_tensor.h @@ -29,6 +29,12 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" namespace paddle { + +namespace recordio { +class Writer; +class Scanner; +} + namespace framework { /* @@ -209,5 +215,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor, void DeserializeFromStream(std::istream& is, LoDTensor* tensor, const platform::DeviceContext& dev_ctx); +extern void WriteToRecordIO(recordio::Writer& writer, + const std::vector& tensor, + const platform::DeviceContext& dev_ctx); + +extern std::vector ReadFromRecordIO( + recordio::Scanner& scanner, const platform::DeviceContext& dev_ctx); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/lod_tensor_test.cc b/paddle/fluid/framework/lod_tensor_test.cc index 5e135192ce774ab5c351b89164be9d7600ae3640..e691e29383d4842b80769021e0e494967d38e9bb 100644 --- a/paddle/fluid/framework/lod_tensor_test.cc +++ b/paddle/fluid/framework/lod_tensor_test.cc @@ -14,6 +14,9 @@ #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/recordio/scanner.h" +#include "paddle/fluid/recordio/writer.h" + #include #include #include @@ -224,5 +227,43 @@ TEST(LoD, CheckAbsLoD) { abs_lod0.push_back(std::vector({0})); ASSERT_FALSE(CheckAbsLoD(abs_lod0)); } + +TEST(LoDTensor, RecordIO) { + LoDTensor tensor; + int* tmp = tensor.mutable_data(make_ddim({4, 5}), platform::CPUPlace()); + for (int i = 0; i < 20; ++i) { + tmp[i] = i; + } + + std::stringstream* stream = new std::stringstream(); + auto& ctx = + *platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); + { + recordio::Writer writer(stream, recordio::Compressor::kSnappy); + WriteToRecordIO(writer, {tensor, tensor}, ctx); + WriteToRecordIO(writer, {tensor, tensor}, ctx); + writer.Flush(); + } + + auto assert_tensor_ok = [](const LoDTensor& tensor) { + for (int i = 0; i < 20; ++i) { + ASSERT_EQ(tensor.data()[i], i); + } + }; + + { + std::unique_ptr stream_ptr(stream); + recordio::Scanner scanner(std::move(stream_ptr)); + auto tensors = ReadFromRecordIO(scanner, ctx); + ASSERT_EQ(tensors.size(), 2); + assert_tensor_ok(tensors[0]); + assert_tensor_ok(tensors[1]); + tensors = ReadFromRecordIO(scanner, ctx); + ASSERT_EQ(tensors.size(), 2); + assert_tensor_ok(tensors[0]); + assert_tensor_ok(tensors[1]); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index e820c3d07e85fd1dea9080786b48ad031330ee00..18064ddc669aad7dda98d502119e56e7ddedcff3 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -33,6 +33,8 @@ class ReaderBase { std::vector shapes() const { return shapes_; } void set_shapes(const std::vector& shapes) { shapes_ = shapes; } + virtual bool HasNext() const = 0; + virtual ~ReaderBase() {} protected: @@ -53,6 +55,8 @@ class DecoratedReader : public ReaderBase { void ReInit() override { reader_->ReInit(); } + bool HasNext() const override { return reader_->HasNext(); } + protected: ReaderBase* reader_; }; @@ -87,6 +91,8 @@ class ReaderHolder { reader_->set_shapes(shapes); } + bool HasNext() const { return reader_->HasNext(); } + private: std::unique_ptr reader_; }; diff --git a/paddle/fluid/operators/detail/safe_ref.h b/paddle/fluid/operators/detail/safe_ref.h index 9cb5851deba6b16261d4499afcfb867d9d706498..48bdce740878ea486eda6821dc29885a3e480114 100644 --- a/paddle/fluid/operators/detail/safe_ref.h +++ b/paddle/fluid/operators/detail/safe_ref.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/platform/enforce.h" + namespace paddle { namespace operators { namespace detail { diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 335c5b26a864381bf87a2824b78f521cdce063e4..744bd3b7ef71f83ad82979eb966369c2e9456a7d 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -1,6 +1,24 @@ cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader) -op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry) -op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry) -op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry) -op_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc DEPS reader_op_registry) -set(READER_LIBRARY create_random_data_generator_op create_shuffle_reader_op create_batch_reader_op create_double_buffer_reader_op PARENT_SCOPE) +set(LOCAL_READER_LIBS) + +function(reader_library TARGET_NAME) + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS) + set(options "") + set(common_deps reader_op_registry) + cmake_parse_arguments(reader_library "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + op_library(${TARGET_NAME} SRCS ${reader_library_SRCS} DEPS ${common_deps} ${reader_library_DEPS}) + set(LOCAL_READER_LIBS + ${TARGET_NAME} + ${LOCAL_READER_LIBS} + PARENT_SCOPE) +endfunction() + +reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) +reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) +reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) +reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) +reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) +# Export local libraries to parent +set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index b6a0609a1e23195ececee0f16a69daa1c1c46ed8..ba08ea12e2486aaba8c57a9fe23592bd1738592d 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -37,6 +37,8 @@ class DoubleBufferReader : public framework::DecoratedReader { ~DoubleBufferReader() { buffer_->Close(); } + bool HasNext() const override; + private: void PrefetchThreadFunc(); @@ -106,6 +108,8 @@ void DoubleBufferReader::PrefetchThreadFunc() { } } +bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); } + } // namespace reader } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index 73c39b5da4484b27f75aeba3c8171c5ffed2398f..e62f952d0e89561c3eed56112dc9d1d78801b59e 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -52,6 +52,8 @@ class RandomDataGenerator : public framework::FileReader { void ReInit() override { return; } + bool HasNext() const override { return true; } + private: float min_; float max_; diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c3eb247bbe2041ae5a673c4fd3c1284c71276f91 --- /dev/null +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reader/reader_op_registry.h" +#include "paddle/fluid/recordio/scanner.h" + +namespace paddle { +namespace operators { +namespace reader { +class RecordIOFileReader : public framework::FileReader { + public: + RecordIOFileReader(const std::string& filename, + const std::vector& shapes) + : FileReader(shapes), + scanner_(filename), + dev_ctx_(*platform::DeviceContextPool::Instance().Get( + platform::CPUPlace())) {} + + void ReadNext(std::vector* out) override { + *out = framework::ReadFromRecordIO(scanner_, dev_ctx_); + } + + bool HasNext() const override { return scanner_.HasNext(); } + + void ReInit() override { scanner_.Reset(); } + + private: + recordio::Scanner scanner_; + const platform::DeviceContext& dev_ctx_; +}; + +class CreateRecordIOReaderOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const auto& shape_concat = Attr>("shape_concat"); + const auto& ranks = Attr>("ranks"); + PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); + PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), + int(shape_concat.size()), + "The accumulate of all ranks should be equal to the " + "shape concat's length."); + std::vector shapes = RestoreShapes(shape_concat, ranks); + std::string filename = Attr("filename"); + + auto* out = scope.FindVar(Output("Out")) + ->template GetMutable(); + out->Reset(new RecordIOFileReader(filename, shapes)); + } +}; + +class CreateRecordIOReaderOpMaker : public FileReaderMakerBase { + public: + CreateRecordIOReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : FileReaderMakerBase(op_proto, op_checker) { + AddAttr("filename", "The filename of record io reader"); + AddComment(R"DOC( + CreateRecordIOReader Operator + + Create a reader from a record io file + )DOC"); + } +}; + +} // namespace reader +} // namespace operators +} // namespace paddle + +namespace reader = paddle::operators::reader; + +REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader, + reader::CreateRecordIOReaderOp, + reader::CreateRecordIOReaderOpMaker); diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index f80769d7cd2d35261cd55fc1d6c8c20197f5e88c..33d4ff4099a509daeaab83032c5d382718904dc7 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -35,7 +35,7 @@ FileReaderMakerBase::FileReaderMakerBase( framework::OpProtoAndCheckerMaker::OpProto* op_proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(op_proto, op_checker) { - AddOutput("Out", "(ReaderHolder) The created random reader."); + AddOutput("Out", "(ReaderHolder) The created random reader.").AsDuplicable(); AddAttr>("shape_concat", "The concat of all data's shapes."); AddAttr>( "ranks", diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index d62f34030894e2fa21925bbc44e24b4e7d738d15..8942b5c9430ffa4e499b0ad1d2b5acf6d18ec0ab 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,6 +1,6 @@ if(WITH_PYTHON) cc_library(paddle_pybind SHARED - SRCS pybind.cc exception.cc protobuf.cc const_value.cc + SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method ${GLOB_OP_LIB}) if(NOT APPLE AND NOT ANDROID) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index ac7d1efb577505b70e10a70cdcfd3ed9c5fe1f5c..d2e883caccdd34a9d662f06b83cf9a71d3d4a51e 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -26,6 +26,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/prune.h" +#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/operators/cond_op.h" #include "paddle/fluid/operators/net_op.h" @@ -35,7 +36,9 @@ limitations under the License. */ #include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/pybind.h" +#include "paddle/fluid/pybind/recordio.h" #include "paddle/fluid/pybind/tensor_py.h" + #include "paddle/fluid/string/to_string.h" #ifdef PADDLE_WITH_CUDA @@ -217,8 +220,18 @@ All parameter, weight, gradient are variables in Paddle. [](Variable &self) -> operators::NetOp * { return self.GetMutable(); }, + py::return_value_policy::reference) + .def("get_reader", + [](Variable &self) -> framework::ReaderHolder * { + PADDLE_ENFORCE(self.IsType()); + return self.GetMutable(); + }, py::return_value_policy::reference); + py::class_(m, "Reader", "") + .def("has_next", &framework::ReaderHolder::HasNext) + .def("reset", &framework::ReaderHolder::ReInit); + py::class_(m, "Scope", "") .def("var", [](Scope &self, const std::string &name) -> Variable * { @@ -474,6 +487,8 @@ All parameter, weight, gradient are variables in Paddle. m.def("enable_profiler", platform::EnableProfiler); m.def("disable_profiler", platform::DisableProfiler); m.def("reset_profiler", platform::ResetProfiler); + + BindRecordIOWriter(m); return m.ptr(); } } // namespace pybind diff --git a/paddle/fluid/pybind/recordio.cc b/paddle/fluid/pybind/recordio.cc new file mode 100644 index 0000000000000000000000000000000000000000..16f8bfb1a2e3a840670594d3cc2970e690dce891 --- /dev/null +++ b/paddle/fluid/pybind/recordio.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pybind/recordio.h" +#include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/recordio/writer.h" + +namespace paddle { +namespace pybind { + +class RecordIOWriter { + public: + RecordIOWriter(const std::string& filename, recordio::Compressor compressor, + size_t max_num_record) + : stream_(filename), writer_(&stream_, compressor, max_num_record) {} + + void AppendTensor(const framework::LoDTensor& tensor) { + tensors_.push_back(tensor); + } + + void CompleteAppendTensor() { + auto& ctx = + *platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); + framework::WriteToRecordIO(writer_, tensors_, ctx); + tensors_.clear(); + } + + void Close() { + PADDLE_ENFORCE(tensors_.empty()); + writer_.Flush(); + stream_.close(); + } + + private: + std::vector tensors_; + std::ofstream stream_; + recordio::Writer writer_; +}; + +void BindRecordIOWriter(py::module& m) { + py::class_ writer(m, "RecordIOWriter", ""); + py::enum_(writer, "Compressor", "") + .value("Snappy", recordio::Compressor::kSnappy) + .value("NoCompress", recordio::Compressor::kNoCompress); + + writer + .def("__init__", + [](RecordIOWriter& self, const std::string& filename, + recordio::Compressor compressor, size_t max_num_record) { + new (&self) RecordIOWriter(filename, compressor, max_num_record); + }) + .def("append_tensor", &RecordIOWriter::AppendTensor) + .def("complete_append_tensor", &RecordIOWriter::CompleteAppendTensor) + .def("close", &RecordIOWriter::Close); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/recordio.h b/paddle/fluid/pybind/recordio.h new file mode 100644 index 0000000000000000000000000000000000000000..60e6a9e8595614b38375fca8c13d520739af9aaf --- /dev/null +++ b/paddle/fluid/pybind/recordio.h @@ -0,0 +1,26 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { + +extern void BindRecordIOWriter(py::module& m); +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/recordio/CMakeLists.txt b/paddle/fluid/recordio/CMakeLists.txt index e1e7c2cdb3d0c960d5cd408420b5aaead73e70d7..92e97a6c85d7c8f01c8473feb9772f2285d49673 100644 --- a/paddle/fluid/recordio/CMakeLists.txt +++ b/paddle/fluid/recordio/CMakeLists.txt @@ -3,4 +3,7 @@ cc_library(header SRCS header.cc) cc_test(header_test SRCS header_test.cc DEPS header) cc_library(chunk SRCS chunk.cc DEPS snappystream snappy header zlib) cc_test(chunk_test SRCS chunk_test.cc DEPS chunk) -cc_library(recordio DEPS chunk header) +cc_library(writer SRCS writer.cc DEPS chunk) +cc_library(scanner SRCS scanner.cc DEPS chunk) +cc_test(writer_scanner_test SRCS writer_scanner_test.cc DEPS writer scanner) +cc_library(recordio DEPS chunk header writer scanner) diff --git a/paddle/fluid/recordio/chunk.cc b/paddle/fluid/recordio/chunk.cc index 587fd375c38ca83e1c65cb3ccc20b3509b6348c7..187a6a4ea7bd9d3a8ae48fa262e18f71b0f7d20d 100644 --- a/paddle/fluid/recordio/chunk.cc +++ b/paddle/fluid/recordio/chunk.cc @@ -24,33 +24,52 @@ namespace paddle { namespace recordio { constexpr size_t kMaxBufSize = 1024; +/** + * Read Stream by a fixed sized buffer. + * @param in input stream + * @param limit read at most `limit` bytes from input stream. 0 means no limit + * @param callback A function object with (const char* buf, size_t size) -> void + * as its type. + */ template -static void ReadStreamByBuf(std::istream& in, int limit, Callback callback) { +static void ReadStreamByBuf(std::istream& in, size_t limit, Callback callback) { char buf[kMaxBufSize]; std::streamsize actual_size; size_t counter = 0; - do { - auto actual_max = - limit > 0 ? std::min(limit - counter, kMaxBufSize) : kMaxBufSize; - actual_size = in.readsome(buf, actual_max); + size_t actual_max; + while (!in.eof() || + (limit != 0 && counter >= limit)) { // End of file or reach limit + actual_max = + limit != 0 ? std::min(limit - counter, kMaxBufSize) : kMaxBufSize; + in.read(buf, actual_max); + actual_size = in.gcount(); if (actual_size == 0) { break; } callback(buf, actual_size); - if (limit > 0) { + if (limit != 0) { counter += actual_size; } - } while (actual_size == kMaxBufSize); + } + in.clear(); // unset eof state } +/** + * Copy stream in to another stream + */ static void PipeStream(std::istream& in, std::ostream& os) { ReadStreamByBuf( - in, -1, [&os](const char* buf, size_t len) { os.write(buf, len); }); + in, 0, [&os](const char* buf, size_t len) { os.write(buf, len); }); } -static uint32_t Crc32Stream(std::istream& in, int limit = -1) { - auto crc = crc32(0, nullptr, 0); + +/** + * Calculate CRC32 from an input stream. + */ +static uint32_t Crc32Stream(std::istream& in, size_t limit = 0) { + uint32_t crc = static_cast(crc32(0, nullptr, 0)); ReadStreamByBuf(in, limit, [&crc](const char* buf, size_t len) { - crc = crc32(crc, reinterpret_cast(buf), len); + crc = static_cast(crc32( + crc, reinterpret_cast(buf), static_cast(len))); }); return crc; } @@ -85,28 +104,29 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const { compressed_stream.reset(); } - auto end_pos = sout.tellg(); + sout.seekg(0, std::ios::end); + uint32_t len = static_cast(sout.tellg()); sout.seekg(0, std::ios::beg); - uint32_t len = static_cast(end_pos - sout.tellg()); uint32_t crc = Crc32Stream(sout); - sout.seekg(0, std::ios::beg); - Header hdr(static_cast(records_.size()), crc, ct, len); hdr.Write(os); + sout.seekg(0, std::ios::beg); + sout.clear(); PipeStream(sout, os); return true; } -void Chunk::Parse(std::istream& sin) { +bool Chunk::Parse(std::istream& sin) { Header hdr; - hdr.Parse(sin); + bool ok = hdr.Parse(sin); + if (!ok) { + return ok; + } auto beg_pos = sin.tellg(); - auto crc = Crc32Stream(sin, hdr.CompressSize()); + uint32_t crc = Crc32Stream(sin, hdr.CompressSize()); PADDLE_ENFORCE_EQ(hdr.Checksum(), crc); - Clear(); - - sin.seekg(beg_pos, std::ios::beg); + sin.seekg(beg_pos, sin.beg); std::unique_ptr compressed_stream; switch (hdr.CompressType()) { case Compressor::kNoCompress: @@ -126,8 +146,10 @@ void Chunk::Parse(std::istream& sin) { std::string buf; buf.resize(rec_len); stream.read(&buf[0], rec_len); + PADDLE_ENFORCE_EQ(rec_len, stream.gcount()); Add(buf); } + return true; } } // namespace recordio diff --git a/paddle/fluid/recordio/chunk.h b/paddle/fluid/recordio/chunk.h index 0ba9c63abbe72e7a51ddb1af5f0d206aa9f6cc5b..bf20ebd455c26ddeebeeea8db04cf7103b0c085f 100644 --- a/paddle/fluid/recordio/chunk.h +++ b/paddle/fluid/recordio/chunk.h @@ -26,9 +26,9 @@ namespace recordio { class Chunk { public: Chunk() : num_bytes_(0) {} - void Add(std::string buf) { - records_.push_back(buf); + void Add(const std::string& buf) { num_bytes_ += buf.size(); + records_.emplace_back(buf); } // dump the chunk into w, and clears the chunk and makes it ready for // the next add invocation. @@ -37,10 +37,15 @@ public: records_.clear(); num_bytes_ = 0; } - void Parse(std::istream& sin); - size_t NumBytes() { return num_bytes_; } + + // returns true if ok, false if eof + bool Parse(std::istream& sin); + size_t NumBytes() const { return num_bytes_; } + size_t NumRecords() const { return records_.size(); } const std::string& Record(int i) const { return records_[i]; } + bool Empty() const { return records_.empty(); } + private: std::vector records_; // sum of record lengths in bytes. diff --git a/paddle/fluid/recordio/chunk_test.cc b/paddle/fluid/recordio/chunk_test.cc index a67ba32ed6ab8bda230d1414975c96a0be6d682b..1f0e36a14d373ca96167199d4582bc8f17290ae8 100644 --- a/paddle/fluid/recordio/chunk_test.cc +++ b/paddle/fluid/recordio/chunk_test.cc @@ -26,7 +26,7 @@ TEST(Chunk, SaveLoad) { ch.Add(std::string("123", 4)); std::stringstream ss; ch.Write(ss, Compressor::kNoCompress); - ch.Clear(); + ss.seekg(0); ch.Parse(ss); ASSERT_EQ(ch.NumBytes(), 10U); } diff --git a/paddle/fluid/recordio/header.cc b/paddle/fluid/recordio/header.cc index 3641caaa8981020519cbc31e5362348c02d3bbce..e50de15b7c2b480357f5f6c7daa2b4a676749679 100644 --- a/paddle/fluid/recordio/header.cc +++ b/paddle/fluid/recordio/header.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/recordio/header.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace recordio { @@ -26,23 +27,33 @@ Header::Header() Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs) : num_records_(num), checksum_(sum), compressor_(c), compress_size_(cs) {} -void Header::Parse(std::istream& is) { +bool Header::Parse(std::istream& is) { + uint32_t magic; + size_t read_size = + is.readsome(reinterpret_cast(&magic), sizeof(uint32_t)); + if (read_size < sizeof(uint32_t)) { + return false; + } + PADDLE_ENFORCE_EQ(magic, kMagicNumber); + is.read(reinterpret_cast(&num_records_), sizeof(uint32_t)) .read(reinterpret_cast(&checksum_), sizeof(uint32_t)) .read(reinterpret_cast(&compressor_), sizeof(uint32_t)) .read(reinterpret_cast(&compress_size_), sizeof(uint32_t)); + return true; } void Header::Write(std::ostream& os) const { - os.write(reinterpret_cast(&num_records_), sizeof(uint32_t)) + os.write(reinterpret_cast(&kMagicNumber), sizeof(uint32_t)) + .write(reinterpret_cast(&num_records_), sizeof(uint32_t)) .write(reinterpret_cast(&checksum_), sizeof(uint32_t)) .write(reinterpret_cast(&compressor_), sizeof(uint32_t)) .write(reinterpret_cast(&compress_size_), sizeof(uint32_t)); } std::ostream& operator<<(std::ostream& os, Header h) { - os << h.NumRecords() << h.Checksum() - << static_cast(h.CompressType()) << h.CompressSize(); + os << "Header: " << h.NumRecords() << ", " << h.Checksum() << ", " + << static_cast(h.CompressType()) << ", " << h.CompressSize(); return os; } diff --git a/paddle/fluid/recordio/header.h b/paddle/fluid/recordio/header.h index cbd52642a668d1eaeeafb672e50af1a476975080..9200ac090de4514bef3704ac502039222eef2284 100644 --- a/paddle/fluid/recordio/header.h +++ b/paddle/fluid/recordio/header.h @@ -19,8 +19,6 @@ namespace paddle { namespace recordio { -// Default ChunkSize -constexpr size_t kDefaultMaxChunkSize = 32 * 1024 * 1024; // MagicNumber for memory checking constexpr uint32_t kMagicNumber = 0x01020304; @@ -44,7 +42,9 @@ public: Header(uint32_t num, uint32_t sum, Compressor ct, uint32_t cs); void Write(std::ostream& os) const; - void Parse(std::istream& is); + + // returns true if OK, false if eof + bool Parse(std::istream& is); uint32_t NumRecords() const { return num_records_; } uint32_t Checksum() const { return checksum_; } diff --git a/paddle/fluid/recordio/scanner.cc b/paddle/fluid/recordio/scanner.cc new file mode 100644 index 0000000000000000000000000000000000000000..d842f8fe5a4c9d1a2b564c738d97fffb02f3ccb5 --- /dev/null +++ b/paddle/fluid/recordio/scanner.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/recordio/scanner.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace recordio { +Scanner::Scanner(std::unique_ptr &&stream) + : stream_(std::move(stream)) { + Reset(); +} + +Scanner::Scanner(const std::string &filename) { + stream_.reset(new std::ifstream(filename)); + Reset(); +} + +void Scanner::Reset() { + stream_->seekg(0, std::ios::beg); + ParseNextChunk(); +} + +std::string Scanner::Next() { + PADDLE_ENFORCE(!eof_, "StopIteration"); + auto rec = cur_chunk_.Record(offset_++); + if (offset_ == cur_chunk_.NumRecords()) { + ParseNextChunk(); + } + return rec; +} + +void Scanner::ParseNextChunk() { + eof_ = !cur_chunk_.Parse(*stream_); + offset_ = 0; +} + +bool Scanner::HasNext() const { return !eof_; } +} // namespace recordio +} // namespace paddle diff --git a/paddle/fluid/recordio/scanner.h b/paddle/fluid/recordio/scanner.h new file mode 100644 index 0000000000000000000000000000000000000000..f3f17b69f195ddd92f5a39ead9755a7b8e2dd329 --- /dev/null +++ b/paddle/fluid/recordio/scanner.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/recordio/chunk.h" +namespace paddle { +namespace recordio { + +class Scanner { +public: + explicit Scanner(std::unique_ptr&& stream); + + explicit Scanner(const std::string& filename); + + void Reset(); + + std::string Next(); + + bool HasNext() const; + +private: + std::unique_ptr stream_; + Chunk cur_chunk_; + size_t offset_; + bool eof_; + + void ParseNextChunk(); +}; +} // namespace recordio +} // namespace paddle diff --git a/paddle/fluid/recordio/writer.cc b/paddle/fluid/recordio/writer.cc new file mode 100644 index 0000000000000000000000000000000000000000..196d66edff8cc6000afcd74fb945c05dcab7106a --- /dev/null +++ b/paddle/fluid/recordio/writer.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/recordio/writer.h" +#include "paddle/fluid/platform/enforce.h" +namespace paddle { +namespace recordio { +void Writer::Write(const std::string& record) { + cur_chunk_.Add(record); + if (cur_chunk_.NumRecords() >= max_num_records_in_chunk_) { + Flush(); + } +} + +void Writer::Flush() { + cur_chunk_.Write(stream_, compressor_); + cur_chunk_.Clear(); +} + +Writer::~Writer() { + PADDLE_ENFORCE(cur_chunk_.Empty(), "Writer must be flushed when destroy."); +} + +} // namespace recordio +} // namespace paddle diff --git a/paddle/fluid/recordio/writer.h b/paddle/fluid/recordio/writer.h new file mode 100644 index 0000000000000000000000000000000000000000..0c478d507547b10b8ebaaf5e512557a5c8c13e65 --- /dev/null +++ b/paddle/fluid/recordio/writer.h @@ -0,0 +1,43 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/recordio/chunk.h" +namespace paddle { +namespace recordio { + +class Writer { +public: + Writer(std::ostream* sout, + Compressor compressor, + size_t max_num_records_in_chunk = 1000) + : stream_(*sout), + max_num_records_in_chunk_(max_num_records_in_chunk), + compressor_(compressor) {} + + void Write(const std::string& record); + + void Flush(); + + ~Writer(); + +private: + std::ostream& stream_; + size_t max_num_records_in_chunk_; + Chunk cur_chunk_; + Compressor compressor_; +}; + +} // namespace recordio +} // namespace paddle diff --git a/paddle/fluid/recordio/writer_scanner_test.cc b/paddle/fluid/recordio/writer_scanner_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e764f0d9439709ad101af2b8864dc0158bd359b --- /dev/null +++ b/paddle/fluid/recordio/writer_scanner_test.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" + +#include +#include "paddle/fluid/recordio/scanner.h" +#include "paddle/fluid/recordio/writer.h" + +TEST(WriterScanner, Normal) { + std::stringstream* stream = new std::stringstream(); + + { + paddle::recordio::Writer writer(stream, + paddle::recordio::Compressor::kSnappy); + writer.Write("ABC"); + writer.Write("BCD"); + writer.Write("CDE"); + writer.Flush(); + } + + { + stream->seekg(0, std::ios::beg); + std::unique_ptr stream_ptr(stream); + paddle::recordio::Scanner scanner(std::move(stream_ptr)); + ASSERT_TRUE(scanner.HasNext()); + ASSERT_EQ(scanner.Next(), "ABC"); + ASSERT_EQ("BCD", scanner.Next()); + ASSERT_TRUE(scanner.HasNext()); + ASSERT_EQ("CDE", scanner.Next()); + ASSERT_FALSE(scanner.HasNext()); + } +} + +TEST(WriterScanner, TinyChunk) { + std::stringstream* stream = new std::stringstream(); + { + paddle::recordio::Writer writer( + stream, paddle::recordio::Compressor::kNoCompress, 2 /*max chunk num*/); + writer.Write("ABC"); + writer.Write("BCD"); + writer.Write("CDE"); + writer.Write("DEFG"); + writer.Flush(); + } + + { + stream->seekg(0, std::ios::beg); + std::unique_ptr stream_ptr(stream); + paddle::recordio::Scanner scanner(std::move(stream_ptr)); + ASSERT_TRUE(scanner.HasNext()); + ASSERT_EQ(scanner.Next(), "ABC"); + ASSERT_EQ(scanner.Next(), "BCD"); + ASSERT_EQ(scanner.Next(), "CDE"); + ASSERT_EQ(scanner.Next(), "DEFG"); + ASSERT_FALSE(scanner.HasNext()); + } +} \ No newline at end of file diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 84a57aff516ad2d7ba1efaf1d530e77747d3b254..dcde08632a6bb4c5936c32048c2cc1dca7608b06 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -40,6 +40,7 @@ import clip from memory_optimization_transpiler import memory_optimize, release_memory import profiler import unique_name +import recordio_writer Tensor = LoDTensor @@ -66,6 +67,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [ 'release_memory', 'profiler', 'unique_name', + 'recordio_writer', ] diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 35aa80a2ae9a6289665b581275fb86c3931fd7a8..1c0f1f6eb415b1c05c1052c1f52743a19c49f017 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -47,7 +47,7 @@ def is_parameter(var): def is_persistable(var): if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ - var.desc.type() == core.VarDesc.VarType.FETCH_LIST: + var.desc.type() == core.VarDesc.VarType.FETCH_LIST: return False return var.persistable diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index af3ae54248a744e7e2fed8190aeeb0eb481cb315..f1b2af70205ab40f08c11061a683b567f5bcbb7b 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -13,11 +13,16 @@ # limitations under the License. from .. import core -from ..layer_helper import LayerHelper +from ..framework import convert_np_dtype_to_dtype_, default_main_program, default_startup_program +from ..unique_name import generate as unique_name from control_flow import BlockGuard from ..layer_helper import LayerHelper +from ..executor import global_scope -__all__ = ['data', 'BlockGuardServ', 'ListenAndServ', 'Send'] +__all__ = [ + 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', + 'read_file' +] def data(name, @@ -224,3 +229,72 @@ def Recv(endpoints, get_vars): outputs={"Out": get_vars}, attrs={"endpoints": endpoints, "epmap": epmap}) + + +def monkey_patch_reader_methods(reader): + def __get_reader__(): + scope = global_scope() + var = scope.find_var(reader.name) + return var.get_reader() + + def eof(): + return not __get_reader__().has_next() + + def reset(): + return __get_reader__().reset() + + reader.eof = eof + reader.reset = reset + return reader + + +def _copy_reader_var_(block, var): + new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER) + new_var.desc.set_shapes(var.desc.shapes()) + new_var.desc.set_dtypes(var.desc.dtypes()) + new_var.persistable = True + return monkey_patch_reader_methods(new_var) + + +def open_recordio_file(filename, shapes, lod_levels, dtypes): + dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] + shape_concat = [] + ranks = [] + + for shape in shapes: + shape_concat.extend(shape) + ranks.append(len(shape)) + + var_name = unique_name('open_recordio_file') + + startup_blk = default_startup_program().current_block() + startup_var = startup_blk.create_var(name=var_name) + startup_blk.append_op( + type='create_recordio_file_reader', + outputs={'Out': [startup_var]}, + attrs={ + 'shape_concat': shape_concat, + 'lod_levels': lod_levels, + 'filename': filename, + 'ranks': ranks + }) + + startup_var.desc.set_dtypes(dtypes) + startup_var.persistable = True + return _copy_reader_var_(default_main_program().current_block(), + startup_var) + + +def read_file(file_obj): + helper = LayerHelper('read_file') + out = [ + helper.create_tmp_variable( + stop_gradient=True, dtype='float32') + for _ in range(len(file_obj.desc.shapes())) + ] + helper.append_op( + type='read', inputs={'Reader': [file_obj]}, outputs={'Out': out}) + if len(out) == 1: + return out[0] + else: + return out diff --git a/python/paddle/fluid/recordio_writer.py b/python/paddle/fluid/recordio_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..9735df8c06113230af9695f76a7589ea9f50e527 --- /dev/null +++ b/python/paddle/fluid/recordio_writer.py @@ -0,0 +1,45 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import core +import contextlib + +__all__ = ['convert_reader_to_recordio_file'] + + +@contextlib.contextmanager +def create_recordio_writer(filename, + compressor=core.RecordIOWriter.Compressor.Snappy, + max_num_records=1000): + writer = core.RecordIOWriter(filename, compressor, max_num_records) + yield writer + writer.close() + + +def convert_reader_to_recordio_file( + filename, + reader_creator, + feeder, + compressor=core.RecordIOWriter.Compressor.Snappy, + max_num_records=1000, + feed_order=None): + if feed_order is None: + feed_order = feeder.feed_names + with create_recordio_writer(filename, compressor, + max_num_records) as writer: + for batch in reader_creator(): + res = feeder.feed(batch) + for each in feed_order: + writer.append_tensor(res[each]) + writer.complete_append_tensor() diff --git a/python/paddle/fluid/tests/unittests/.gitignore b/python/paddle/fluid/tests/unittests/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..6b3fc2a83c649c28d21c9a8a0b35c2f2fa04f269 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/.gitignore @@ -0,0 +1 @@ +mnist.recordio diff --git a/python/paddle/fluid/tests/unittests/test_recordio_reader.py b/python/paddle/fluid/tests/unittests/test_recordio_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..d249742bd30ec41749f16beaa7076f7c6e8f063c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_recordio_reader.py @@ -0,0 +1,64 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle.fluid as fluid +import paddle.v2.dataset.mnist as mnist +import paddle.v2 as paddle + + +class TestRecordIO(unittest.TestCase): + def setUp(self): + # Convert mnist to recordio file + with fluid.program_guard(fluid.Program(), fluid.Program()): + reader = paddle.batch(mnist.train(), batch_size=32) + feeder = fluid.DataFeeder( + feed_list=[ # order is image and label + fluid.layers.data( + name='image', shape=[784]), + fluid.layers.data( + name='label', shape=[1], dtype='int64'), + ], + place=fluid.CPUPlace()) + fluid.recordio_writer.convert_reader_to_recordio_file( + './mnist.recordio', reader, feeder) + + def test_main(self): + # use new program + with fluid.program_guard(fluid.Program(), fluid.Program()): + data_file = fluid.layers.open_recordio_file( + './mnist.recordio', + shapes=[[-1, 784], [-1, 1]], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + img, label = fluid.layers.read_file(data_file) + + hidden = fluid.layers.fc(input=img, size=100, act='tanh') + prediction = fluid.layers.fc(input=hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + avg_loss = fluid.layers.mean(loss) + + fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_loss) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + avg_loss_np = [] + + # train a pass + while not data_file.eof(): + tmp, = exe.run(fetch_list=[avg_loss]) + avg_loss_np.append(tmp) + data_file.reset() + + self.assertLess(avg_loss_np[-1], avg_loss_np[0])