未验证 提交 90084a25 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #9743 from JiayiFeng/modify_readers_to_fit_parallel_executor

Modify readers to fit the parallel executor
......@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
......@@ -22,11 +27,6 @@ limitations under the License. */
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
namespace paddle {
namespace framework {
......@@ -294,7 +294,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
TensorFromStream(is, static_cast<Tensor *>(tensor), dev_ctx);
}
void WriteToRecordIO(recordio::Writer &writer,
void WriteToRecordIO(recordio::Writer *writer,
const std::vector<LoDTensor> &tensor,
const platform::DeviceContext &dev_ctx) {
std::stringstream buffer;
......@@ -303,18 +303,20 @@ void WriteToRecordIO(recordio::Writer &writer,
for (auto &each : tensor) {
SerializeToStream(buffer, each, dev_ctx);
}
writer.Write(buffer.str());
writer->Write(buffer.str());
}
std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner &scanner, const platform::DeviceContext &dev_ctx) {
std::istringstream sin(scanner.Next());
uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
recordio::Scanner *scanner, const platform::DeviceContext &dev_ctx) {
std::vector<LoDTensor> result;
result.resize(sz);
for (uint32_t i = 0; i < sz; ++i) {
DeserializeFromStream(sin, &result[i], dev_ctx);
if (scanner->HasNext()) {
std::istringstream sin(scanner->Next());
uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
result.resize(sz);
for (uint32_t i = 0; i < sz; ++i) {
DeserializeFromStream(sin, &result[i], dev_ctx);
}
}
return result;
}
......
......@@ -15,6 +15,9 @@ limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
......@@ -216,12 +219,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
void DeserializeFromStream(std::istream& is, LoDTensor* tensor,
const platform::DeviceContext& dev_ctx);
extern void WriteToRecordIO(recordio::Writer& writer,
extern void WriteToRecordIO(recordio::Writer* writer,
const std::vector<LoDTensor>& tensor,
const platform::DeviceContext& dev_ctx);
extern std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner& scanner, const platform::DeviceContext& dev_ctx);
recordio::Scanner* scanner, const platform::DeviceContext& dev_ctx);
} // namespace framework
} // namespace paddle
......@@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
namespace paddle {
namespace framework {
......@@ -240,8 +240,8 @@ TEST(LoDTensor, RecordIO) {
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
{
recordio::Writer writer(stream, recordio::Compressor::kSnappy);
WriteToRecordIO(writer, {tensor, tensor}, ctx);
WriteToRecordIO(writer, {tensor, tensor}, ctx);
WriteToRecordIO(&writer, {tensor, tensor}, ctx);
WriteToRecordIO(&writer, {tensor, tensor}, ctx);
writer.Flush();
}
......@@ -254,11 +254,11 @@ TEST(LoDTensor, RecordIO) {
{
std::unique_ptr<std::istream> stream_ptr(stream);
recordio::Scanner scanner(std::move(stream_ptr));
auto tensors = ReadFromRecordIO(scanner, ctx);
auto tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
tensors = ReadFromRecordIO(scanner, ctx);
tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
......
......@@ -115,14 +115,12 @@ void ParallelExecutor::BCastParamsToGPUs(
for (auto &var : vars) {
auto *main_var = main_scope->FindVar(var);
if (!main_var->IsType<LoDTensor>()) {
if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue;
}
auto &main_tensor = main_var->Get<LoDTensor>();
auto &dims = main_tensor.dims();
if (paddle::platform::is_gpu_place(main_tensor.place())) {
size_t numel = main_tensor.numel();
ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
......
......@@ -22,7 +22,9 @@ FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
void FileReader::ReadNext(std::vector<LoDTensor> *out) {
ReadNextImpl(out);
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
if (out->empty()) {
return;
}
for (size_t i = 0; i < dims_.size(); ++i) {
auto &actual = out->at(i).dims();
auto &expect = dims_[i];
......
......@@ -14,14 +14,13 @@
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
#include <memory>
#include <thread>
#include <vector>
namespace paddle {
namespace framework {
......@@ -31,8 +30,6 @@ class ReaderBase {
virtual void ReInit() = 0;
virtual bool HasNext() const = 0;
virtual ~ReaderBase();
};
......@@ -44,8 +41,6 @@ class DecoratedReader : public ReaderBase {
void ReInit() override { reader_->ReInit(); }
bool HasNext() const override { return reader_->HasNext(); }
protected:
ReaderBase* reader_;
};
......@@ -80,8 +75,6 @@ class ReaderHolder {
reader_->ReInit();
}
bool HasNext() const { return reader_->HasNext(); }
private:
std::unique_ptr<ReaderBase> reader_;
};
......
......@@ -66,13 +66,7 @@ class ReadOp : public framework::OperatorBase {
std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (ins.empty()) {
reader->ReInit();
reader->ReadNext(&ins);
PADDLE_ENFORCE(
!ins.empty(),
"Reader can not read the next data even it has been re-initialized.");
}
PADDLE_ENFORCE(!ins.empty(), "There is no next data.");
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
for (size_t i = 0; i < ins.size(); ++i) {
auto* out =
......
......@@ -22,5 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
# Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
......@@ -63,13 +63,14 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher();
}
bool HasNext() const override;
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
~DoubleBufferReader() { EndPrefetcher(); }
private:
bool HasNext() const;
void StartPrefetcher() {
channel_ = framework::MakeChannel<Item>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
......@@ -109,7 +110,9 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
auto place_str = Attr<std::string>("place");
platform::Place place;
if (place_str == "CPU") {
if (place_str == "AUTO") {
place = dev_place;
} else if (place_str == "CPU") {
place = platform::CPUPlace();
} else {
std::istringstream sin(place_str);
......@@ -140,28 +143,22 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
enum_range.insert(string::Sprintf("CUDA:%d", i));
}
enum_range.insert("CPU");
AddAttr<std::string>("place", "The double buffer place, default is CPU")
.SetDefault("CPU")
enum_range.insert("AUTO");
AddAttr<std::string>("place", "The double buffer place")
.SetDefault("AUTO")
.InEnum({enum_range});
}
};
bool DoubleBufferReader::HasNext() const {
while (!channel_->IsClosed() && !channel_->CanReceive()) {
}
return channel_->CanReceive();
}
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
Item batch;
channel_->Receive(&batch);
*out = batch.payloads_;
if (batch.ctx_) {
batch.ctx_->Wait();
out->clear();
if (HasNext()) {
Item batch;
channel_->Receive(&batch);
*out = batch.payloads_;
if (batch.ctx_) {
batch.ctx_->Wait();
}
}
}
......@@ -171,16 +168,26 @@ void DoubleBufferReader::ReInit() {
StartPrefetcher();
}
bool DoubleBufferReader::HasNext() const {
while (!channel_->IsClosed() && !channel_->CanReceive()) {
}
return channel_->CanReceive();
}
void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
size_t cached_tensor_id = 0;
while (reader_->HasNext()) {
while (true) {
Item batch;
auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
reader_->ReadNext(&cpu_batch);
if (cpu_batch.empty()) {
// The underlying reader have no next data.
break;
}
if (platform::is_gpu_place(place_)) {
auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
auto* gpu_ctx = ctxs_[cached_tensor_id].get();
......
......@@ -25,22 +25,12 @@ class MultiPassReader : public framework::DecoratedReader {
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
reader_->ReadNext(out);
}
bool HasNext() const override {
if (reader_->HasNext()) {
return true;
} else {
if (out->empty()) {
++pass_count_;
if (pass_count_ >= pass_num_) {
return false;
} else {
if (pass_count_ < pass_num_) {
reader_->ReInit();
return true;
reader_->ReadNext(out);
}
}
}
......
......@@ -52,8 +52,6 @@ class RandomDataGenerator : public framework::ReaderBase {
void ReInit() override { return; }
bool HasNext() const override { return true; }
private:
float min_;
float max_;
......@@ -74,7 +72,7 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
......
......@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include <thread>
#include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/recordio/scanner.h"
......@@ -35,17 +33,15 @@ class RecordIOFileReader : public framework::FileReader {
LOG(INFO) << "Creating file reader" << filename;
}
bool HasNext() const override { return scanner_.HasNext(); }
void ReInit() override { scanner_.Reset(); }
protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
if (ThreadSafe) {
std::lock_guard<std::mutex> guard(*mutex_);
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
*out = framework::ReadFromRecordIO(&scanner_, dev_ctx_);
} else {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
*out = framework::ReadFromRecordIO(&scanner_, dev_ctx_);
}
}
......@@ -66,7 +62,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::string filename = Attr<std::string>("filename");
......
......@@ -30,35 +30,33 @@ class ShuffleReader : public framework::DecoratedReader {
std::random_device device;
seed_ = device();
}
ReadIntoBuffers();
ReloadBuffer();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
out->clear();
if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer";
ReadIntoBuffers();
ReloadBuffer();
if (buffer_.empty()) {
return;
}
}
*out = buffer_[iteration_pos_++];
}
bool HasNext() const override {
return iteration_pos_ < buffer_.size() || reader_->HasNext();
}
private:
void ReadIntoBuffers() {
void ReloadBuffer() {
buffer_.clear();
buffer_.reserve(buffer_size_);
iteration_pos_ = 0;
for (size_t i = 0; i < buffer_size_; ++i) {
if (!reader_->HasNext()) {
std::vector<framework::LoDTensor> ins;
reader_->ReadNext(&ins);
if (ins.empty()) {
break;
}
buffer_.emplace_back();
reader_->ReadNext(&buffer_.back());
buffer_.emplace_back(ins);
}
std::mt19937 g(seed_);
std::shuffle(buffer_.begin(), buffer_.end(), g);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class ThreadedReader : public framework::DecoratedReader {
public:
ThreadedReader(ReaderBase* reader, bool safe_mode)
: DecoratedReader(reader), safe_mode_(safe_mode) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_);
reader_->ReadNext(out);
}
void ReInit() override {
if (safe_mode_) {
PADDLE_THROW(
"ThreadedReader::ReInit() is disabled when 'safe_mode' is true.");
}
VLOG(5) << "ThreadedReader::ReInit() is invoked! It might be buggy in "
"multi-thread environment.";
reader_->ReInit();
}
private:
bool safe_mode_;
std::mutex mutex_;
};
class CreateThreadedReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto* out = detail::Ref(scope.FindVar(Output("Out")))
.GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
bool safe_mode = Attr<bool>("safe_mode");
out->Reset(new ThreadedReader(underlying_reader.Get(), safe_mode));
}
};
class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateThreadedReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<bool>("safe_mode",
"When 'safe_mode' is true, 'ReInit()' is disabled to avoid "
"unexpected bugs in multi-thread environment.")
.SetDefault(true);
AddComment(R"DOC(
CreateThreadedReader Operator
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace reader = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_threaded_reader,
reader::CreateThreadedReaderOp,
reader::CreateThreadedReaderOpMaker);
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thread> // NOLINT
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
......@@ -19,38 +21,23 @@ namespace paddle {
namespace operators {
namespace reader {
class MultipleReader : public framework::ReaderBase {
class MultiFileReader : public framework::ReaderBase {
public:
class ThreadBufferMap {
public:
std::vector<framework::LoDTensor>& operator[](
const std::thread::id& thread_id) {
std::lock_guard<std::mutex> lock(mutex_);
return buffer_[thread_id];
}
void Clear() { buffer_.clear(); }
private:
std::mutex mutex_;
std::unordered_map<std::thread::id, std::vector<framework::LoDTensor>>
buffer_;
};
MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num)
: file_names_(file_names), dims_(dims) {
MultiFileReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num,
size_t buffer_size)
: file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
prefetchers_.resize(thread_num);
StartNewScheduler();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
bool HasNext() const override;
void ReInit() override;
~MultipleReader() { EndScheduler(); }
~MultiFileReader() { EndScheduler(); }
private:
bool HasNext();
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc();
......@@ -60,39 +47,36 @@ class MultipleReader : public framework::ReaderBase {
std::vector<framework::DDim> dims_;
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
size_t buffer_size_;
framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
mutable ThreadBufferMap thread_buffer_map_;
};
void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
out->clear();
if (HasNext()) {
buffer_->Receive(out);
}
auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
*out = thread_local_buffer;
thread_local_buffer.clear();
}
bool MultipleReader::HasNext() const {
auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
return thread_local_buffer.empty() ? buffer_->Receive(&thread_local_buffer)
: true;
}
void MultipleReader::ReInit() {
void MultiFileReader::ReInit() {
EndScheduler();
thread_buffer_map_.Clear();
StartNewScheduler();
}
void MultipleReader::StartNewScheduler() {
bool MultiFileReader::HasNext() {
while (!buffer_->IsClosed() && !buffer_->CanReceive()) {
}
return buffer_->CanReceive();
}
void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
framework::MakeChannel<std::vector<framework::LoDTensor>>(buffer_size_);
for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i);
......@@ -105,7 +89,7 @@ void MultipleReader::StartNewScheduler() {
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}
void MultipleReader::EndScheduler() {
void MultiFileReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_file_idx_->Close();
......@@ -117,8 +101,8 @@ void MultipleReader::EndScheduler() {
delete waiting_file_idx_;
}
void MultipleReader::ScheduleThreadFunc() {
VLOG(5) << "MultipleReader schedule thread starts.";
void MultiFileReader::ScheduleThreadFunc() {
VLOG(5) << "MultiFileReader schedule thread starts.";
size_t completed_thread_num = 0;
size_t thread_idx;
while (available_thread_idx_->Receive(&thread_idx)) {
......@@ -150,17 +134,20 @@ void MultipleReader::ScheduleThreadFunc() {
p.join();
}
}
VLOG(5) << "MultipleReader schedule thread terminates.";
VLOG(5) << "MultiFileReader schedule thread terminates.";
}
void MultipleReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) {
void MultiFileReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) {
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_);
while (reader->HasNext()) {
while (true) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (ins.empty()) {
break;
}
try {
buffer_->Send(&ins);
} catch (paddle::platform::EnforceNotMet e) {
......@@ -197,11 +184,13 @@ class OpenFilesOp : public framework::OperatorBase {
const auto& file_names = Attr<std::vector<std::string>>("file_names");
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
const size_t thread_num = Attr<int>("thread_num");
const size_t buffer_size = Attr<int>("buffer_size");
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultipleReader(
file_names, RestoreShapes(shape_concat, ranks), thread_num));
out->Reset(new MultiFileReader(file_names,
RestoreShapes(shape_concat, ranks),
thread_num, buffer_size));
}
};
......@@ -212,11 +201,12 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0);
AddAttr<int>("buffer_size", "The size of prefetch buffer.").GreaterThan(0);
AddComment(R"DOC(
OpenFiles Operator
An OpenFilesOp creates a MultipleReader, which is able to
An OpenFilesOp creates a MultiFileReader, which is able to
read data multi-threaded from multiple files.
)DOC");
}
......
......@@ -252,7 +252,6 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("has_next", &framework::ReaderHolder::HasNext)
.def("reset", &framework::ReaderHolder::ReInit);
py::class_<Scope>(m, "Scope", "")
......
......@@ -39,7 +39,7 @@ class RecordIOWriter {
void CompleteAppendTensor() {
auto& ctx =
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
framework::WriteToRecordIO(writer_, tensors_, ctx);
framework::WriteToRecordIO(&writer_, tensors_, ctx);
tensors_.clear();
}
......
......@@ -21,8 +21,7 @@ from ..executor import global_scope
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'create_shuffle_reader',
'create_double_buffer_reader', 'create_multi_pass_reader'
'open_files', 'read_file', 'shuffle', 'double_buffer'
]
......@@ -237,13 +236,9 @@ def monkey_patch_reader_methods(reader):
var = scope.find_var(reader.name)
return var.get_reader()
def eof():
return not __get_reader__().has_next()
def reset():
return __get_reader__().reset()
reader.eof = eof
reader.reset = reset
reader.stop_gradient = True
reader.persistable = True
......@@ -283,7 +278,42 @@ def _copy_reader_create_op_(block, op):
return new_op
def open_recordio_file(filename, shapes, lod_levels, dtypes):
def open_recordio_file(filename,
shapes,
lod_levels,
dtypes,
pass_num=1,
for_parallel=False):
"""
Open a RecordIO file
This layer takes a RecordIO file to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from the given RecordIO file.
Args:
filename(str): The RecordIO file's name.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable via which we can get RecordIO file data.
Examples:
.. code-block:: python
reader = fluid.layers.io.open_recordio_file(
filename='./data.recordio',
shapes=[(3,224,224), (1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader)
"""
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []
......@@ -310,10 +340,63 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
if pass_num > 1:
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
if for_parallel:
main_prog_var = parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var)
def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
def open_files(filenames,
shapes,
lod_levels,
dtypes,
thread_num,
buffer_size=None,
pass_num=1,
for_parallel=False):
"""
Open files
This layer takes a list of files to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from given files. All files must
have name suffixs to indicate their formats, e.g., '*.recordio'.
Args:
filenames(list): The list of file names.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int): The size of prefetch buffer.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable via which we can get file data.
Examples:
.. code-block:: python
reader = fluid.layers.io.open_files(filenames=['./data1.recordio',
'./data2.recordio'],
shapes=[(3,224,224), (1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=2,
buffer_size=2)
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader)
"""
if buffer_size is None:
buffer_size = thread_num
if isinstance(filenames, basestring):
filenames = [filenames]
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []
......@@ -322,29 +405,36 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
shape_concat.extend(shape)
ranks.append(len(shape))
var_name = unique_name('multiple_reader')
multi_file_reader_name = unique_name('multi_file_reader')
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_reader = startup_blk.create_var(name=multi_file_reader_name)
startup_blk.append_op(
type='open_files',
outputs={'Out': [startup_var]},
outputs={'Out': [startup_reader]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks,
'file_names': filenames,
'thread_num': thread_num
'thread_num': thread_num,
'buffer_size': buffer_size
})
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
return monkey_patch_reader_methods(main_prog_var)
startup_reader.desc.set_dtypes(dtypes)
startup_reader.persistable = True
main_prog_reader = _copy_reader_var_(default_main_program().current_block(),
startup_reader)
if pass_num > 1:
main_prog_reader = multi_pass(
reader=main_prog_reader, pass_num=pass_num)
if for_parallel:
main_prog_reader = parallel(reader=main_prog_reader)
return monkey_patch_reader_methods(main_prog_reader)
def __create_decorated_reader__(op_type, reader, attrs):
def __create_shared_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
......@@ -360,22 +450,41 @@ def __create_decorated_reader__(op_type, reader, attrs):
return monkey_patch_reader_methods(main_prog_var)
def create_shuffle_reader(reader, buffer_size):
return __create_decorated_reader__('create_shuffle_reader', reader,
{'buffer_size': int(buffer_size)})
def __create_unshared_decorated_reader__(op_type, reader, attrs):
new_reader_name = unique_name(op_type)
main_blk = default_main_program().current_block()
new_reader = main_blk.create_var(name=new_reader_name)
main_blk.append_op(
type=op_type,
inputs={'UnderlyingReader': reader},
outputs={'Out': [new_reader]},
attrs=attrs)
new_reader.persistable = True
new_reader.stop_gradient = True
return monkey_patch_reader_methods(new_reader)
def shuffle(reader, buffer_size):
return __create_unshared_decorated_reader__(
'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)})
def create_double_buffer_reader(reader, place=None):
def double_buffer(reader, place=None):
attrs = dict()
if place is not None:
attrs['place'] = str(place).upper()
return __create_decorated_reader__('create_double_buffer_reader', reader,
attrs)
return __create_unshared_decorated_reader__('create_double_buffer_reader',
reader, attrs)
def multi_pass(reader, pass_num):
return __create_shared_decorated_reader__(
'create_multi_pass_reader', reader, {'pass_num': int(pass_num)})
def create_multi_pass_reader(reader, pass_num):
return __create_decorated_reader__('create_multi_pass_reader', reader,
{'pass_num': int(pass_num)})
def parallel(reader):
return __create_shared_decorated_reader__('create_threaded_reader', reader,
{})
def read_file(file_obj):
......
......@@ -61,8 +61,12 @@ class TestMultipleReader(unittest.TestCase):
exe.run(fluid.default_startup_program())
batch_count = 0
while not data_files.eof():
img_val, = exe.run(fetch_list=[img])
while True:
try:
img_val, = exe.run(fetch_list=[img])
except fluid.core.EnforceNotMet as ex:
self.assertIn("There is no next data.", ex.message)
break
batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size)
data_files.reset()
......
......@@ -44,7 +44,7 @@ class TestMultipleReader(unittest.TestCase):
shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
data_file = fluid.layers.create_multi_pass_reader(
data_file = fluid.layers.io.multi_pass(
reader=data_file, pass_num=self.pass_num)
img, label = fluid.layers.read_file(data_file)
......@@ -57,8 +57,12 @@ class TestMultipleReader(unittest.TestCase):
exe.run(fluid.default_startup_program())
batch_count = 0
while not data_file.eof():
img_val, = exe.run(fetch_list=[img])
while True:
try:
img_val, = exe.run(fetch_list=[img])
except fluid.core.EnforceNotMet as ex:
self.assertIn("There is no next data.", ex.message)
break
batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size)
data_file.reset()
......
......@@ -26,11 +26,14 @@ def simple_fc_net(use_feed):
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else:
reader = fluid.layers.open_recordio_file(
filename='./mnist.recordio',
reader = fluid.layers.open_files(
filenames=['./mnist.recordio'],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
dtypes=['float32', 'int64'],
thread_num=1,
for_parallel=True)
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
hidden = img
for _ in xrange(4):
......@@ -51,11 +54,14 @@ def fc_with_batchnorm(use_feed):
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else:
reader = fluid.layers.open_recordio_file(
filename='./mnist.recordio',
reader = fluid.layers.open_files(
filenames=['mnist.recordio'],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
dtypes=['float32', 'int64'],
thread_num=1,
for_parallel=True)
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
hidden = img
......
......@@ -65,8 +65,13 @@ class TestRecordIO(unittest.TestCase):
# train a pass
batch_id = 0
while not data_file.eof():
tmp, = exe.run(fetch_list=[avg_loss])
while True:
try:
tmp, = exe.run(fetch_list=[avg_loss])
except fluid.core.EnforceNotMet as ex:
self.assertIn("There is no next data.", ex.message)
break
avg_loss_np.append(tmp)
batch_id += 1
data_file.reset()
......@@ -74,8 +79,8 @@ class TestRecordIO(unittest.TestCase):
self.assertLess(avg_loss_np[-1], avg_loss_np[0])
def test_shuffle_reader(self):
self.test_main(decorator_callback=lambda reader: fluid.layers.create_shuffle_reader(reader, buffer_size=200))
self.test_main(decorator_callback=lambda reader: fluid.layers.io.shuffle(reader, buffer_size=200))
def test_double_buffer_reader(self):
self.test_main(decorator_callback=lambda reader: fluid.layers.create_double_buffer_reader(reader,
self.test_main(decorator_callback=lambda reader: fluid.layers.io.double_buffer(reader,
place='cuda:0' if fluid.core.is_compiled_with_cuda() else 'cpu'))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册