未验证 提交 90084a25 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #9743 from JiayiFeng/modify_readers_to_fit_parallel_executor

Modify readers to fit the parallel executor
...@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
...@@ -22,11 +27,6 @@ limitations under the License. */ ...@@ -22,11 +27,6 @@ limitations under the License. */
#include "paddle/fluid/recordio/scanner.h" #include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h" #include "paddle/fluid/recordio/writer.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -294,7 +294,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, ...@@ -294,7 +294,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
TensorFromStream(is, static_cast<Tensor *>(tensor), dev_ctx); TensorFromStream(is, static_cast<Tensor *>(tensor), dev_ctx);
} }
void WriteToRecordIO(recordio::Writer &writer, void WriteToRecordIO(recordio::Writer *writer,
const std::vector<LoDTensor> &tensor, const std::vector<LoDTensor> &tensor,
const platform::DeviceContext &dev_ctx) { const platform::DeviceContext &dev_ctx) {
std::stringstream buffer; std::stringstream buffer;
...@@ -303,18 +303,20 @@ void WriteToRecordIO(recordio::Writer &writer, ...@@ -303,18 +303,20 @@ void WriteToRecordIO(recordio::Writer &writer,
for (auto &each : tensor) { for (auto &each : tensor) {
SerializeToStream(buffer, each, dev_ctx); SerializeToStream(buffer, each, dev_ctx);
} }
writer.Write(buffer.str()); writer->Write(buffer.str());
} }
std::vector<LoDTensor> ReadFromRecordIO( std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner &scanner, const platform::DeviceContext &dev_ctx) { recordio::Scanner *scanner, const platform::DeviceContext &dev_ctx) {
std::istringstream sin(scanner.Next());
uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
std::vector<LoDTensor> result; std::vector<LoDTensor> result;
result.resize(sz); if (scanner->HasNext()) {
for (uint32_t i = 0; i < sz; ++i) { std::istringstream sin(scanner->Next());
DeserializeFromStream(sin, &result[i], dev_ctx); uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
result.resize(sz);
for (uint32_t i = 0; i < sz; ++i) {
DeserializeFromStream(sin, &result[i], dev_ctx);
}
} }
return result; return result;
} }
......
...@@ -15,6 +15,9 @@ limitations under the License. */ ...@@ -15,6 +15,9 @@ limitations under the License. */
#pragma once #pragma once
#include <memory> #include <memory>
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
...@@ -216,12 +219,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor, ...@@ -216,12 +219,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
void DeserializeFromStream(std::istream& is, LoDTensor* tensor, void DeserializeFromStream(std::istream& is, LoDTensor* tensor,
const platform::DeviceContext& dev_ctx); const platform::DeviceContext& dev_ctx);
extern void WriteToRecordIO(recordio::Writer& writer, extern void WriteToRecordIO(recordio::Writer* writer,
const std::vector<LoDTensor>& tensor, const std::vector<LoDTensor>& tensor,
const platform::DeviceContext& dev_ctx); const platform::DeviceContext& dev_ctx);
extern std::vector<LoDTensor> ReadFromRecordIO( extern std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner& scanner, const platform::DeviceContext& dev_ctx); recordio::Scanner* scanner, const platform::DeviceContext& dev_ctx);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,17 +12,17 @@ ...@@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -240,8 +240,8 @@ TEST(LoDTensor, RecordIO) { ...@@ -240,8 +240,8 @@ TEST(LoDTensor, RecordIO) {
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); *platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
{ {
recordio::Writer writer(stream, recordio::Compressor::kSnappy); recordio::Writer writer(stream, recordio::Compressor::kSnappy);
WriteToRecordIO(writer, {tensor, tensor}, ctx); WriteToRecordIO(&writer, {tensor, tensor}, ctx);
WriteToRecordIO(writer, {tensor, tensor}, ctx); WriteToRecordIO(&writer, {tensor, tensor}, ctx);
writer.Flush(); writer.Flush();
} }
...@@ -254,11 +254,11 @@ TEST(LoDTensor, RecordIO) { ...@@ -254,11 +254,11 @@ TEST(LoDTensor, RecordIO) {
{ {
std::unique_ptr<std::istream> stream_ptr(stream); std::unique_ptr<std::istream> stream_ptr(stream);
recordio::Scanner scanner(std::move(stream_ptr)); recordio::Scanner scanner(std::move(stream_ptr));
auto tensors = ReadFromRecordIO(scanner, ctx); auto tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2); ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]); assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]); assert_tensor_ok(tensors[1]);
tensors = ReadFromRecordIO(scanner, ctx); tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2); ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]); assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]); assert_tensor_ok(tensors[1]);
......
...@@ -115,14 +115,12 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -115,14 +115,12 @@ void ParallelExecutor::BCastParamsToGPUs(
for (auto &var : vars) { for (auto &var : vars) {
auto *main_var = main_scope->FindVar(var); auto *main_var = main_scope->FindVar(var);
if (!main_var->IsType<LoDTensor>()) { if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue; continue;
} }
auto &main_tensor = main_var->Get<LoDTensor>(); auto &main_tensor = main_var->Get<LoDTensor>();
auto &dims = main_tensor.dims(); auto &dims = main_tensor.dims();
if (paddle::platform::is_gpu_place(main_tensor.place())) { if (paddle::platform::is_gpu_place(main_tensor.place())) {
size_t numel = main_tensor.numel(); size_t numel = main_tensor.numel();
ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type()); ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
......
...@@ -22,7 +22,9 @@ FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {} ...@@ -22,7 +22,9 @@ FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
void FileReader::ReadNext(std::vector<LoDTensor> *out) { void FileReader::ReadNext(std::vector<LoDTensor> *out) {
ReadNextImpl(out); ReadNextImpl(out);
PADDLE_ENFORCE_EQ(out->size(), dims_.size()); if (out->empty()) {
return;
}
for (size_t i = 0; i < dims_.size(); ++i) { for (size_t i = 0; i < dims_.size(); ++i) {
auto &actual = out->at(i).dims(); auto &actual = out->at(i).dims();
auto &expect = dims_[i]; auto &expect = dims_[i];
......
...@@ -14,14 +14,13 @@ ...@@ -14,14 +14,13 @@
#pragma once #pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include <memory>
#include <thread>
#include <vector>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -31,8 +30,6 @@ class ReaderBase { ...@@ -31,8 +30,6 @@ class ReaderBase {
virtual void ReInit() = 0; virtual void ReInit() = 0;
virtual bool HasNext() const = 0;
virtual ~ReaderBase(); virtual ~ReaderBase();
}; };
...@@ -44,8 +41,6 @@ class DecoratedReader : public ReaderBase { ...@@ -44,8 +41,6 @@ class DecoratedReader : public ReaderBase {
void ReInit() override { reader_->ReInit(); } void ReInit() override { reader_->ReInit(); }
bool HasNext() const override { return reader_->HasNext(); }
protected: protected:
ReaderBase* reader_; ReaderBase* reader_;
}; };
...@@ -80,8 +75,6 @@ class ReaderHolder { ...@@ -80,8 +75,6 @@ class ReaderHolder {
reader_->ReInit(); reader_->ReInit();
} }
bool HasNext() const { return reader_->HasNext(); }
private: private:
std::unique_ptr<ReaderBase> reader_; std::unique_ptr<ReaderBase> reader_;
}; };
......
...@@ -66,13 +66,7 @@ class ReadOp : public framework::OperatorBase { ...@@ -66,13 +66,7 @@ class ReadOp : public framework::OperatorBase {
std::vector<std::string> out_arg_names = Outputs("Out"); std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins; std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins); reader->ReadNext(&ins);
if (ins.empty()) { PADDLE_ENFORCE(!ins.empty(), "There is no next data.");
reader->ReInit();
reader->ReadNext(&ins);
PADDLE_ENFORCE(
!ins.empty(),
"Reader can not read the next data even it has been re-initialized.");
}
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
for (size_t i = 0; i < ins.size(); ++i) { for (size_t i = 0; i < ins.size(); ++i) {
auto* out = auto* out =
......
...@@ -22,5 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) ...@@ -22,5 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc) reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
# Export local libraries to parent # Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
...@@ -63,13 +63,14 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -63,13 +63,14 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher(); StartPrefetcher();
} }
bool HasNext() const override;
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override; void ReInit() override;
~DoubleBufferReader() { EndPrefetcher(); } ~DoubleBufferReader() { EndPrefetcher(); }
private: private:
bool HasNext() const;
void StartPrefetcher() { void StartPrefetcher() {
channel_ = framework::MakeChannel<Item>(kChannelSize); channel_ = framework::MakeChannel<Item>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); }); prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
...@@ -109,7 +110,9 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { ...@@ -109,7 +110,9 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
auto place_str = Attr<std::string>("place"); auto place_str = Attr<std::string>("place");
platform::Place place; platform::Place place;
if (place_str == "CPU") { if (place_str == "AUTO") {
place = dev_place;
} else if (place_str == "CPU") {
place = platform::CPUPlace(); place = platform::CPUPlace();
} else { } else {
std::istringstream sin(place_str); std::istringstream sin(place_str);
...@@ -140,28 +143,22 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -140,28 +143,22 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
enum_range.insert(string::Sprintf("CUDA:%d", i)); enum_range.insert(string::Sprintf("CUDA:%d", i));
} }
enum_range.insert("CPU"); enum_range.insert("CPU");
AddAttr<std::string>("place", "The double buffer place, default is CPU") enum_range.insert("AUTO");
.SetDefault("CPU") AddAttr<std::string>("place", "The double buffer place")
.SetDefault("AUTO")
.InEnum({enum_range}); .InEnum({enum_range});
} }
}; };
bool DoubleBufferReader::HasNext() const {
while (!channel_->IsClosed() && !channel_->CanReceive()) {
}
return channel_->CanReceive();
}
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) { out->clear();
PADDLE_THROW("There is no next data!"); if (HasNext()) {
} Item batch;
channel_->Receive(&batch);
Item batch; *out = batch.payloads_;
channel_->Receive(&batch); if (batch.ctx_) {
*out = batch.payloads_; batch.ctx_->Wait();
if (batch.ctx_) { }
batch.ctx_->Wait();
} }
} }
...@@ -171,16 +168,26 @@ void DoubleBufferReader::ReInit() { ...@@ -171,16 +168,26 @@ void DoubleBufferReader::ReInit() {
StartPrefetcher(); StartPrefetcher();
} }
bool DoubleBufferReader::HasNext() const {
while (!channel_->IsClosed() && !channel_->CanReceive()) {
}
return channel_->CanReceive();
}
void DoubleBufferReader::PrefetchThreadFunc() { void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts."; VLOG(5) << "A new prefetch thread starts.";
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize); std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize); std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
size_t cached_tensor_id = 0; size_t cached_tensor_id = 0;
while (reader_->HasNext()) { while (true) {
Item batch; Item batch;
auto& cpu_batch = cpu_tensor_cache[cached_tensor_id]; auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
reader_->ReadNext(&cpu_batch); reader_->ReadNext(&cpu_batch);
if (cpu_batch.empty()) {
// The underlying reader have no next data.
break;
}
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
auto& gpu_batch = gpu_tensor_cache[cached_tensor_id]; auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
auto* gpu_ctx = ctxs_[cached_tensor_id].get(); auto* gpu_ctx = ctxs_[cached_tensor_id].get();
......
...@@ -25,22 +25,12 @@ class MultiPassReader : public framework::DecoratedReader { ...@@ -25,22 +25,12 @@ class MultiPassReader : public framework::DecoratedReader {
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {} : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
reader_->ReadNext(out); reader_->ReadNext(out);
} if (out->empty()) {
bool HasNext() const override {
if (reader_->HasNext()) {
return true;
} else {
++pass_count_; ++pass_count_;
if (pass_count_ >= pass_num_) { if (pass_count_ < pass_num_) {
return false;
} else {
reader_->ReInit(); reader_->ReInit();
return true; reader_->ReadNext(out);
} }
} }
} }
......
...@@ -52,8 +52,6 @@ class RandomDataGenerator : public framework::ReaderBase { ...@@ -52,8 +52,6 @@ class RandomDataGenerator : public framework::ReaderBase {
void ReInit() override { return; } void ReInit() override { return; }
bool HasNext() const override { return true; }
private: private:
float min_; float min_;
float max_; float max_;
...@@ -74,7 +72,7 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase { ...@@ -74,7 +72,7 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
const auto& ranks = Attr<std::vector<int>>("ranks"); const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()), static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the " "The accumulate of all ranks should be equal to the "
"shape concat's length."); "shape concat's length.");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks); std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
......
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <mutex>
#include <thread>
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/recordio/scanner.h" #include "paddle/fluid/recordio/scanner.h"
...@@ -35,17 +33,15 @@ class RecordIOFileReader : public framework::FileReader { ...@@ -35,17 +33,15 @@ class RecordIOFileReader : public framework::FileReader {
LOG(INFO) << "Creating file reader" << filename; LOG(INFO) << "Creating file reader" << filename;
} }
bool HasNext() const override { return scanner_.HasNext(); }
void ReInit() override { scanner_.Reset(); } void ReInit() override { scanner_.Reset(); }
protected: protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
if (ThreadSafe) { if (ThreadSafe) {
std::lock_guard<std::mutex> guard(*mutex_); std::lock_guard<std::mutex> guard(*mutex_);
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_); *out = framework::ReadFromRecordIO(&scanner_, dev_ctx_);
} else { } else {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_); *out = framework::ReadFromRecordIO(&scanner_, dev_ctx_);
} }
} }
...@@ -66,7 +62,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { ...@@ -66,7 +62,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
const auto& ranks = Attr<std::vector<int>>("ranks"); const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()), static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the " "The accumulate of all ranks should be equal to the "
"shape concat's length."); "shape concat's length.");
std::string filename = Attr<std::string>("filename"); std::string filename = Attr<std::string>("filename");
......
...@@ -30,35 +30,33 @@ class ShuffleReader : public framework::DecoratedReader { ...@@ -30,35 +30,33 @@ class ShuffleReader : public framework::DecoratedReader {
std::random_device device; std::random_device device;
seed_ = device(); seed_ = device();
} }
ReadIntoBuffers(); ReloadBuffer();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) { out->clear();
PADDLE_THROW("There is no next data!");
}
if (iteration_pos_ >= buffer_.size()) { if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer"; VLOG(10) << "Resetting shuffle buffer";
ReadIntoBuffers(); ReloadBuffer();
if (buffer_.empty()) {
return;
}
} }
*out = buffer_[iteration_pos_++]; *out = buffer_[iteration_pos_++];
} }
bool HasNext() const override {
return iteration_pos_ < buffer_.size() || reader_->HasNext();
}
private: private:
void ReadIntoBuffers() { void ReloadBuffer() {
buffer_.clear(); buffer_.clear();
buffer_.reserve(buffer_size_); buffer_.reserve(buffer_size_);
iteration_pos_ = 0; iteration_pos_ = 0;
for (size_t i = 0; i < buffer_size_; ++i) { for (size_t i = 0; i < buffer_size_; ++i) {
if (!reader_->HasNext()) { std::vector<framework::LoDTensor> ins;
reader_->ReadNext(&ins);
if (ins.empty()) {
break; break;
} }
buffer_.emplace_back(); buffer_.emplace_back(ins);
reader_->ReadNext(&buffer_.back());
} }
std::mt19937 g(seed_); std::mt19937 g(seed_);
std::shuffle(buffer_.begin(), buffer_.end(), g); std::shuffle(buffer_.begin(), buffer_.end(), g);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class ThreadedReader : public framework::DecoratedReader {
public:
ThreadedReader(ReaderBase* reader, bool safe_mode)
: DecoratedReader(reader), safe_mode_(safe_mode) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_);
reader_->ReadNext(out);
}
void ReInit() override {
if (safe_mode_) {
PADDLE_THROW(
"ThreadedReader::ReInit() is disabled when 'safe_mode' is true.");
}
VLOG(5) << "ThreadedReader::ReInit() is invoked! It might be buggy in "
"multi-thread environment.";
reader_->ReInit();
}
private:
bool safe_mode_;
std::mutex mutex_;
};
class CreateThreadedReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto* out = detail::Ref(scope.FindVar(Output("Out")))
.GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
bool safe_mode = Attr<bool>("safe_mode");
out->Reset(new ThreadedReader(underlying_reader.Get(), safe_mode));
}
};
class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateThreadedReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<bool>("safe_mode",
"When 'safe_mode' is true, 'ReInit()' is disabled to avoid "
"unexpected bugs in multi-thread environment.")
.SetDefault(true);
AddComment(R"DOC(
CreateThreadedReader Operator
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace reader = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_threaded_reader,
reader::CreateThreadedReaderOp,
reader::CreateThreadedReaderOpMaker);
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <thread> // NOLINT
#include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
...@@ -19,38 +21,23 @@ namespace paddle { ...@@ -19,38 +21,23 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
class MultipleReader : public framework::ReaderBase { class MultiFileReader : public framework::ReaderBase {
public: public:
class ThreadBufferMap { MultiFileReader(const std::vector<std::string>& file_names,
public: const std::vector<framework::DDim>& dims, size_t thread_num,
std::vector<framework::LoDTensor>& operator[]( size_t buffer_size)
const std::thread::id& thread_id) { : file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
std::lock_guard<std::mutex> lock(mutex_);
return buffer_[thread_id];
}
void Clear() { buffer_.clear(); }
private:
std::mutex mutex_;
std::unordered_map<std::thread::id, std::vector<framework::LoDTensor>>
buffer_;
};
MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num)
: file_names_(file_names), dims_(dims) {
prefetchers_.resize(thread_num); prefetchers_.resize(thread_num);
StartNewScheduler(); StartNewScheduler();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNext(std::vector<framework::LoDTensor>* out) override;
bool HasNext() const override;
void ReInit() override; void ReInit() override;
~MultipleReader() { EndScheduler(); } ~MultiFileReader() { EndScheduler(); }
private: private:
bool HasNext();
void StartNewScheduler(); void StartNewScheduler();
void EndScheduler(); void EndScheduler();
void ScheduleThreadFunc(); void ScheduleThreadFunc();
...@@ -60,39 +47,36 @@ class MultipleReader : public framework::ReaderBase { ...@@ -60,39 +47,36 @@ class MultipleReader : public framework::ReaderBase {
std::vector<framework::DDim> dims_; std::vector<framework::DDim> dims_;
std::thread scheduler_; std::thread scheduler_;
std::vector<std::thread> prefetchers_; std::vector<std::thread> prefetchers_;
size_t buffer_size_;
framework::Channel<size_t>* waiting_file_idx_; framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<size_t>* available_thread_idx_; framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_; framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
mutable ThreadBufferMap thread_buffer_map_;
}; };
void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) { void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) { out->clear();
PADDLE_THROW("There is no next data!"); if (HasNext()) {
buffer_->Receive(out);
} }
auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
*out = thread_local_buffer;
thread_local_buffer.clear();
}
bool MultipleReader::HasNext() const {
auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
return thread_local_buffer.empty() ? buffer_->Receive(&thread_local_buffer)
: true;
} }
void MultipleReader::ReInit() { void MultiFileReader::ReInit() {
EndScheduler(); EndScheduler();
thread_buffer_map_.Clear();
StartNewScheduler(); StartNewScheduler();
} }
void MultipleReader::StartNewScheduler() { bool MultiFileReader::HasNext() {
while (!buffer_->IsClosed() && !buffer_->CanReceive()) {
}
return buffer_->CanReceive();
}
void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size(); size_t thread_num = prefetchers_.size();
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size()); waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num); available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
buffer_ = buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num); framework::MakeChannel<std::vector<framework::LoDTensor>>(buffer_size_);
for (size_t i = 0; i < file_names_.size(); ++i) { for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i); waiting_file_idx_->Send(&i);
...@@ -105,7 +89,7 @@ void MultipleReader::StartNewScheduler() { ...@@ -105,7 +89,7 @@ void MultipleReader::StartNewScheduler() {
scheduler_ = std::thread([this] { ScheduleThreadFunc(); }); scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
} }
void MultipleReader::EndScheduler() { void MultiFileReader::EndScheduler() {
available_thread_idx_->Close(); available_thread_idx_->Close();
buffer_->Close(); buffer_->Close();
waiting_file_idx_->Close(); waiting_file_idx_->Close();
...@@ -117,8 +101,8 @@ void MultipleReader::EndScheduler() { ...@@ -117,8 +101,8 @@ void MultipleReader::EndScheduler() {
delete waiting_file_idx_; delete waiting_file_idx_;
} }
void MultipleReader::ScheduleThreadFunc() { void MultiFileReader::ScheduleThreadFunc() {
VLOG(5) << "MultipleReader schedule thread starts."; VLOG(5) << "MultiFileReader schedule thread starts.";
size_t completed_thread_num = 0; size_t completed_thread_num = 0;
size_t thread_idx; size_t thread_idx;
while (available_thread_idx_->Receive(&thread_idx)) { while (available_thread_idx_->Receive(&thread_idx)) {
...@@ -150,17 +134,20 @@ void MultipleReader::ScheduleThreadFunc() { ...@@ -150,17 +134,20 @@ void MultipleReader::ScheduleThreadFunc() {
p.join(); p.join();
} }
} }
VLOG(5) << "MultipleReader schedule thread terminates."; VLOG(5) << "MultiFileReader schedule thread terminates.";
} }
void MultipleReader::PrefetchThreadFunc(std::string file_name, void MultiFileReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) { size_t thread_idx) {
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader = std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_); CreateReaderByFileName(file_name, dims_);
while (reader->HasNext()) { while (true) {
std::vector<framework::LoDTensor> ins; std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins); reader->ReadNext(&ins);
if (ins.empty()) {
break;
}
try { try {
buffer_->Send(&ins); buffer_->Send(&ins);
} catch (paddle::platform::EnforceNotMet e) { } catch (paddle::platform::EnforceNotMet e) {
...@@ -197,11 +184,13 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -197,11 +184,13 @@ class OpenFilesOp : public framework::OperatorBase {
const auto& file_names = Attr<std::vector<std::string>>("file_names"); const auto& file_names = Attr<std::vector<std::string>>("file_names");
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!"); PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
const size_t thread_num = Attr<int>("thread_num"); const size_t thread_num = Attr<int>("thread_num");
const size_t buffer_size = Attr<int>("buffer_size");
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultipleReader( out->Reset(new MultiFileReader(file_names,
file_names, RestoreShapes(shape_concat, ranks), thread_num)); RestoreShapes(shape_concat, ranks),
thread_num, buffer_size));
} }
}; };
...@@ -212,11 +201,12 @@ class OpenFilesOpMaker : public FileReaderMakerBase { ...@@ -212,11 +201,12 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
AddAttr<std::vector<std::string>>("file_names", "Files to be read."); AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.") AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0); .GreaterThan(0);
AddAttr<int>("buffer_size", "The size of prefetch buffer.").GreaterThan(0);
AddComment(R"DOC( AddComment(R"DOC(
OpenFiles Operator OpenFiles Operator
An OpenFilesOp creates a MultipleReader, which is able to An OpenFilesOp creates a MultiFileReader, which is able to
read data multi-threaded from multiple files. read data multi-threaded from multiple files.
)DOC"); )DOC");
} }
......
...@@ -252,7 +252,6 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -252,7 +252,6 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference); py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "") py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("has_next", &framework::ReaderHolder::HasNext)
.def("reset", &framework::ReaderHolder::ReInit); .def("reset", &framework::ReaderHolder::ReInit);
py::class_<Scope>(m, "Scope", "") py::class_<Scope>(m, "Scope", "")
......
...@@ -39,7 +39,7 @@ class RecordIOWriter { ...@@ -39,7 +39,7 @@ class RecordIOWriter {
void CompleteAppendTensor() { void CompleteAppendTensor() {
auto& ctx = auto& ctx =
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); *platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
framework::WriteToRecordIO(writer_, tensors_, ctx); framework::WriteToRecordIO(&writer_, tensors_, ctx);
tensors_.clear(); tensors_.clear();
} }
......
...@@ -21,8 +21,7 @@ from ..executor import global_scope ...@@ -21,8 +21,7 @@ from ..executor import global_scope
__all__ = [ __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'create_shuffle_reader', 'open_files', 'read_file', 'shuffle', 'double_buffer'
'create_double_buffer_reader', 'create_multi_pass_reader'
] ]
...@@ -237,13 +236,9 @@ def monkey_patch_reader_methods(reader): ...@@ -237,13 +236,9 @@ def monkey_patch_reader_methods(reader):
var = scope.find_var(reader.name) var = scope.find_var(reader.name)
return var.get_reader() return var.get_reader()
def eof():
return not __get_reader__().has_next()
def reset(): def reset():
return __get_reader__().reset() return __get_reader__().reset()
reader.eof = eof
reader.reset = reset reader.reset = reset
reader.stop_gradient = True reader.stop_gradient = True
reader.persistable = True reader.persistable = True
...@@ -283,7 +278,42 @@ def _copy_reader_create_op_(block, op): ...@@ -283,7 +278,42 @@ def _copy_reader_create_op_(block, op):
return new_op return new_op
def open_recordio_file(filename, shapes, lod_levels, dtypes): def open_recordio_file(filename,
shapes,
lod_levels,
dtypes,
pass_num=1,
for_parallel=False):
"""
Open a RecordIO file
This layer takes a RecordIO file to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from the given RecordIO file.
Args:
filename(str): The RecordIO file's name.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable via which we can get RecordIO file data.
Examples:
.. code-block:: python
reader = fluid.layers.io.open_recordio_file(
filename='./data.recordio',
shapes=[(3,224,224), (1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader)
"""
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = [] shape_concat = []
ranks = [] ranks = []
...@@ -310,10 +340,63 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): ...@@ -310,10 +340,63 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var.persistable = True startup_var.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(), main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var) startup_var)
if pass_num > 1:
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
if for_parallel:
main_prog_var = parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
def open_files(filenames, thread_num, shapes, lod_levels, dtypes): def open_files(filenames,
shapes,
lod_levels,
dtypes,
thread_num,
buffer_size=None,
pass_num=1,
for_parallel=False):
"""
Open files
This layer takes a list of files to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from given files. All files must
have name suffixs to indicate their formats, e.g., '*.recordio'.
Args:
filenames(list): The list of file names.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int): The size of prefetch buffer.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable via which we can get file data.
Examples:
.. code-block:: python
reader = fluid.layers.io.open_files(filenames=['./data1.recordio',
'./data2.recordio'],
shapes=[(3,224,224), (1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=2,
buffer_size=2)
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader)
"""
if buffer_size is None:
buffer_size = thread_num
if isinstance(filenames, basestring):
filenames = [filenames]
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = [] shape_concat = []
ranks = [] ranks = []
...@@ -322,29 +405,36 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes): ...@@ -322,29 +405,36 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
shape_concat.extend(shape) shape_concat.extend(shape)
ranks.append(len(shape)) ranks.append(len(shape))
var_name = unique_name('multiple_reader') multi_file_reader_name = unique_name('multi_file_reader')
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name) startup_reader = startup_blk.create_var(name=multi_file_reader_name)
startup_blk.append_op( startup_blk.append_op(
type='open_files', type='open_files',
outputs={'Out': [startup_var]}, outputs={'Out': [startup_reader]},
attrs={ attrs={
'shape_concat': shape_concat, 'shape_concat': shape_concat,
'lod_levels': lod_levels, 'lod_levels': lod_levels,
'ranks': ranks, 'ranks': ranks,
'file_names': filenames, 'file_names': filenames,
'thread_num': thread_num 'thread_num': thread_num,
'buffer_size': buffer_size
}) })
startup_var.desc.set_dtypes(dtypes) startup_reader.desc.set_dtypes(dtypes)
startup_var.persistable = True startup_reader.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(), main_prog_reader = _copy_reader_var_(default_main_program().current_block(),
startup_var) startup_reader)
return monkey_patch_reader_methods(main_prog_var) if pass_num > 1:
main_prog_reader = multi_pass(
reader=main_prog_reader, pass_num=pass_num)
if for_parallel:
main_prog_reader = parallel(reader=main_prog_reader)
return monkey_patch_reader_methods(main_prog_reader)
def __create_decorated_reader__(op_type, reader, attrs): def __create_shared_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type) var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name) startup_var = startup_blk.create_var(name=var_name)
...@@ -360,22 +450,41 @@ def __create_decorated_reader__(op_type, reader, attrs): ...@@ -360,22 +450,41 @@ def __create_decorated_reader__(op_type, reader, attrs):
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
def create_shuffle_reader(reader, buffer_size): def __create_unshared_decorated_reader__(op_type, reader, attrs):
return __create_decorated_reader__('create_shuffle_reader', reader, new_reader_name = unique_name(op_type)
{'buffer_size': int(buffer_size)}) main_blk = default_main_program().current_block()
new_reader = main_blk.create_var(name=new_reader_name)
main_blk.append_op(
type=op_type,
inputs={'UnderlyingReader': reader},
outputs={'Out': [new_reader]},
attrs=attrs)
new_reader.persistable = True
new_reader.stop_gradient = True
return monkey_patch_reader_methods(new_reader)
def shuffle(reader, buffer_size):
return __create_unshared_decorated_reader__(
'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)})
def create_double_buffer_reader(reader, place=None): def double_buffer(reader, place=None):
attrs = dict() attrs = dict()
if place is not None: if place is not None:
attrs['place'] = str(place).upper() attrs['place'] = str(place).upper()
return __create_decorated_reader__('create_double_buffer_reader', reader, return __create_unshared_decorated_reader__('create_double_buffer_reader',
attrs) reader, attrs)
def multi_pass(reader, pass_num):
return __create_shared_decorated_reader__(
'create_multi_pass_reader', reader, {'pass_num': int(pass_num)})
def create_multi_pass_reader(reader, pass_num): def parallel(reader):
return __create_decorated_reader__('create_multi_pass_reader', reader, return __create_shared_decorated_reader__('create_threaded_reader', reader,
{'pass_num': int(pass_num)}) {})
def read_file(file_obj): def read_file(file_obj):
......
...@@ -61,8 +61,12 @@ class TestMultipleReader(unittest.TestCase): ...@@ -61,8 +61,12 @@ class TestMultipleReader(unittest.TestCase):
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
batch_count = 0 batch_count = 0
while not data_files.eof(): while True:
img_val, = exe.run(fetch_list=[img]) try:
img_val, = exe.run(fetch_list=[img])
except fluid.core.EnforceNotMet as ex:
self.assertIn("There is no next data.", ex.message)
break
batch_count += 1 batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size) self.assertLessEqual(img_val.shape[0], self.batch_size)
data_files.reset() data_files.reset()
......
...@@ -44,7 +44,7 @@ class TestMultipleReader(unittest.TestCase): ...@@ -44,7 +44,7 @@ class TestMultipleReader(unittest.TestCase):
shapes=[(-1, 784), (-1, 1)], shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64']) dtypes=['float32', 'int64'])
data_file = fluid.layers.create_multi_pass_reader( data_file = fluid.layers.io.multi_pass(
reader=data_file, pass_num=self.pass_num) reader=data_file, pass_num=self.pass_num)
img, label = fluid.layers.read_file(data_file) img, label = fluid.layers.read_file(data_file)
...@@ -57,8 +57,12 @@ class TestMultipleReader(unittest.TestCase): ...@@ -57,8 +57,12 @@ class TestMultipleReader(unittest.TestCase):
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
batch_count = 0 batch_count = 0
while not data_file.eof(): while True:
img_val, = exe.run(fetch_list=[img]) try:
img_val, = exe.run(fetch_list=[img])
except fluid.core.EnforceNotMet as ex:
self.assertIn("There is no next data.", ex.message)
break
batch_count += 1 batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size) self.assertLessEqual(img_val.shape[0], self.batch_size)
data_file.reset() data_file.reset()
......
...@@ -26,11 +26,14 @@ def simple_fc_net(use_feed): ...@@ -26,11 +26,14 @@ def simple_fc_net(use_feed):
img = fluid.layers.data(name='image', shape=[784], dtype='float32') img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else: else:
reader = fluid.layers.open_recordio_file( reader = fluid.layers.open_files(
filename='./mnist.recordio', filenames=['./mnist.recordio'],
shapes=[[-1, 784], [-1, 1]], shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64']) dtypes=['float32', 'int64'],
thread_num=1,
for_parallel=True)
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader) img, label = fluid.layers.read_file(reader)
hidden = img hidden = img
for _ in xrange(4): for _ in xrange(4):
...@@ -51,11 +54,14 @@ def fc_with_batchnorm(use_feed): ...@@ -51,11 +54,14 @@ def fc_with_batchnorm(use_feed):
img = fluid.layers.data(name='image', shape=[784], dtype='float32') img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else: else:
reader = fluid.layers.open_recordio_file( reader = fluid.layers.open_files(
filename='./mnist.recordio', filenames=['mnist.recordio'],
shapes=[[-1, 784], [-1, 1]], shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64']) dtypes=['float32', 'int64'],
thread_num=1,
for_parallel=True)
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader) img, label = fluid.layers.read_file(reader)
hidden = img hidden = img
......
...@@ -65,8 +65,13 @@ class TestRecordIO(unittest.TestCase): ...@@ -65,8 +65,13 @@ class TestRecordIO(unittest.TestCase):
# train a pass # train a pass
batch_id = 0 batch_id = 0
while not data_file.eof(): while True:
tmp, = exe.run(fetch_list=[avg_loss]) try:
tmp, = exe.run(fetch_list=[avg_loss])
except fluid.core.EnforceNotMet as ex:
self.assertIn("There is no next data.", ex.message)
break
avg_loss_np.append(tmp) avg_loss_np.append(tmp)
batch_id += 1 batch_id += 1
data_file.reset() data_file.reset()
...@@ -74,8 +79,8 @@ class TestRecordIO(unittest.TestCase): ...@@ -74,8 +79,8 @@ class TestRecordIO(unittest.TestCase):
self.assertLess(avg_loss_np[-1], avg_loss_np[0]) self.assertLess(avg_loss_np[-1], avg_loss_np[0])
def test_shuffle_reader(self): def test_shuffle_reader(self):
self.test_main(decorator_callback=lambda reader: fluid.layers.create_shuffle_reader(reader, buffer_size=200)) self.test_main(decorator_callback=lambda reader: fluid.layers.io.shuffle(reader, buffer_size=200))
def test_double_buffer_reader(self): def test_double_buffer_reader(self):
self.test_main(decorator_callback=lambda reader: fluid.layers.create_double_buffer_reader(reader, self.test_main(decorator_callback=lambda reader: fluid.layers.io.double_buffer(reader,
place='cuda:0' if fluid.core.is_compiled_with_cuda() else 'cpu')) place='cuda:0' if fluid.core.is_compiled_with_cuda() else 'cpu'))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册