未验证 提交 7d6afee5 编写于 作者: Y yuyang18

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/exception_safe_pe

...@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) ...@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim) cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(reader_test SRCS reader_test.cc DEPS reader)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
......
...@@ -13,29 +13,61 @@ ...@@ -13,29 +13,61 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include <deque>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
ReaderBase::~ReaderBase() {}
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {} void ReaderBase::ReadNext(std::vector<LoDTensor> *out) {
std::lock_guard<std::mutex> lock(mu_);
void FileReader::ReadNext(std::vector<LoDTensor> *out) { PADDLE_ENFORCE_EQ(status_, ReaderStatus::kRunning);
ReadNextImpl(out); ReadNextImpl(out);
if (out->empty()) { }
return;
}
PADDLE_ENFORCE_EQ(out->size(), dims_.size()); void ReaderBase::InsertDecoratedReader(
for (size_t i = 0; i < dims_.size(); ++i) { const std::shared_ptr<ReaderBase> &decorated_reader) {
auto &actual = (*out)[i].dims(); std::lock_guard<std::mutex> guard(mu_);
auto &expect = dims_[i]; decorated_readers_.emplace_back(decorated_reader);
}
PADDLE_ENFORCE_EQ(actual.size(), expect.size()); std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
for (int j = 0; j < actual.size(); ++j) { std::unordered_set<ReaderBase *> result;
// PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1); std::deque<ReaderBase *> queue;
queue.emplace_back(this);
while (!queue.empty()) { // BFS search
auto *front = queue.front();
queue.pop_front();
if (front->decorated_readers_.empty()) {
result.emplace(front);
} else {
for (auto &reader : front->decorated_readers_) {
if (auto *reader_ptr = reader.lock().get()) {
queue.emplace_back(reader_ptr);
}
}
} }
} }
return result;
} }
void ReaderBase::Shutdown() {
std::lock_guard<std::mutex> lock(mu_);
if (status_ != ReaderStatus::kStopped) {
ShutdownImpl();
status_ = ReaderStatus::kStopped;
}
}
void ReaderBase::Start() {
std::lock_guard<std::mutex> lock(mu_);
if (status_ != ReaderStatus::kRunning) {
StartImpl();
status_ = ReaderStatus::kRunning;
}
}
ReaderBase::~ReaderBase() { Shutdown(); }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
...@@ -24,61 +25,116 @@ ...@@ -24,61 +25,116 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
enum ReaderStatus { kRunning, kStopped };
class ReaderBase { class ReaderBase {
public: public:
virtual void ReadNext(std::vector<LoDTensor>* out) = 0; void ReadNext(std::vector<LoDTensor>* out);
void Shutdown();
virtual void ReInit() = 0; void Start();
// Return the readers which are the end of decorating chain. Basically
// they are readers just before read op.
std::unordered_set<ReaderBase*> GetEndPoints();
virtual ~ReaderBase(); virtual ~ReaderBase();
protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
virtual void ShutdownImpl() {}
virtual void StartImpl() {}
ReaderStatus status_{kRunning};
mutable std::mutex mu_;
private:
friend class DecoratedReader;
// These methods can be only invoked inside DecoratedReader to record the
// decorating chain.
void InsertDecoratedReader(
const std::shared_ptr<ReaderBase>& decorated_reader);
// A set of which readers that decorated this reader.
std::vector<std::weak_ptr<ReaderBase>> decorated_readers_;
}; };
class DecoratedReader : public ReaderBase { class DecoratedReader : public ReaderBase,
public std::enable_shared_from_this<DecoratedReader> {
public: public:
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader) explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) { : ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
} }
void ReInit() override { reader_->ReInit(); } void RegisterDecorateChain() {
reader_->InsertDecoratedReader(shared_from_this());
}
protected: protected:
std::shared_ptr<ReaderBase> reader_; void ShutdownImpl() override { reader_->Shutdown(); }
};
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& dims);
void ReadNext(std::vector<LoDTensor>* out) override;
protected: void StartImpl() override { reader_->Start(); }
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
private: std::shared_ptr<ReaderBase> reader_;
std::vector<DDim> dims_;
}; };
// FileReader is just a conceptual class.
class FileReader : public ReaderBase {};
// The ReaderHolder is used as reader' unified wrapper, // The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables. // making it easier to access different type reader in Variables.
class ReaderHolder { class ReaderHolder {
public: public:
void Reset(ReaderBase* reader) { reader_.reset(reader); } template <typename T>
void Reset(const std::shared_ptr<T>& reader) {
auto reader_base = std::dynamic_pointer_cast<ReaderBase>(reader);
PADDLE_ENFORCE_NOT_NULL(reader_base);
reader_ = reader_base;
}
std::shared_ptr<ReaderBase> Get() const { return reader_; } const std::shared_ptr<ReaderBase>& Get() const { return reader_; }
void ReadNext(std::vector<LoDTensor>* out) { void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReadNext(out); reader_->ReadNext(out);
} }
void ReInit() {
void ResetAll() {
auto end_readers = reader_->GetEndPoints();
for (auto* reader : end_readers) {
reader->Shutdown();
}
for (auto* reader : end_readers) {
reader->Start();
}
}
void Shutdown() {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->Shutdown();
}
void Start() {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReInit(); reader_->Start();
} }
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
private: private:
std::shared_ptr<ReaderBase> reader_; std::shared_ptr<ReaderBase> reader_;
}; };
template <typename T, typename... ARGS>
inline std::shared_ptr<DecoratedReader> MakeDecoratedReader(ARGS&&... args) {
std::shared_ptr<DecoratedReader> reader(new T(std::forward<ARGS>(args)...));
reader->RegisterDecorateChain();
return reader;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <memory>
#include "gtest/gtest.h"
class StubDecoratedReader : public paddle::framework::DecoratedReader {
public:
explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader)
: DecoratedReader(reader) {}
void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
};
class StubRootReader : public paddle::framework::ReaderBase {
public:
void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
};
TEST(READER, decorate_chain) {
auto root = std::make_shared<StubRootReader>();
auto end_point1 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
auto end_point2 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
{
auto endpoints = root->GetEndPoints();
ASSERT_EQ(endpoints.size(), 2U);
ASSERT_NE(endpoints.count(end_point1.get()), 0);
ASSERT_NE(endpoints.count(end_point2.get()), 0);
}
{
auto end_point3 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
ASSERT_EQ(root->GetEndPoints().size(), 3U);
}
{ ASSERT_EQ(root->GetEndPoints().size(), 2U); }
}
...@@ -22,7 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) ...@@ -22,7 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc) reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc) reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)
reader_library(create_py_reader_op SRCS create_py_reader_op.cc) reader_library(create_py_reader_op SRCS create_py_reader_op.cc)
......
...@@ -20,15 +20,19 @@ namespace reader { ...@@ -20,15 +20,19 @@ namespace reader {
class BatchReader : public framework::DecoratedReader { class BatchReader : public framework::DecoratedReader {
public: public:
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size) BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size,
: DecoratedReader(reader), batch_size_(batch_size) { bool discard_leftover)
: DecoratedReader(reader),
batch_size_(batch_size),
discard_leftover_(discard_leftover) {
buffer_.reserve(batch_size_); buffer_.reserve(batch_size_);
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private: private:
int batch_size_; int batch_size_;
bool discard_leftover_;
std::vector<std::vector<framework::LoDTensor>> buffer_; std::vector<std::vector<framework::LoDTensor>> buffer_;
}; };
...@@ -46,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase { ...@@ -46,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(framework::MakeDecoratedReader<BatchReader>(
new BatchReader(underlying_reader.Get(), Attr<int>("batch_size"))); underlying_reader, Attr<int>("batch_size"),
Attr<bool>("discard_leftover")));
} }
}; };
...@@ -57,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -57,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
AddAttr<int>("batch_size", AddAttr<int>("batch_size",
"How many instances the batch reader yields each time.") "How many instances the batch reader yields each time.")
.GreaterThan(0); .GreaterThan(0);
AddAttr<bool>("discard_leftover",
"If true, the leftover instances that are not enough for a "
"new batch will be discarded.")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
CreateBatchReader Operator CreateBatchReader Operator
...@@ -66,7 +75,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -66,7 +75,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
} }
}; };
void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) { void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
buffer_.clear(); buffer_.clear();
buffer_.reserve(batch_size_); buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) { for (int i = 0; i < batch_size_; ++i) {
...@@ -77,6 +86,9 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) { ...@@ -77,6 +86,9 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
break; break;
} }
} }
if (discard_leftover_ && buffer_.size() < batch_size_) {
buffer_.clear();
}
// Concat instances // Concat instances
out->clear(); out->clear();
if (buffer_.empty()) { if (buffer_.empty()) {
......
...@@ -33,7 +33,7 @@ class CustomReader : public framework::DecoratedReader { ...@@ -33,7 +33,7 @@ class CustomReader : public framework::DecoratedReader {
source_var_names_(source_var_names), source_var_names_(source_var_names),
sink_var_names_(sink_var_names) {} sink_var_names_(sink_var_names) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private: private:
const framework::ProgramDesc program_; const framework::ProgramDesc program_;
...@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase { ...@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(framework::MakeDecoratedReader<CustomReader>(
new CustomReader(underlying_reader.Get(), *sub_block, underlying_reader, *sub_block,
Attr<std::vector<std::string>>("source_var_names"), Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names"))); Attr<std::vector<std::string>>("sink_var_names")));
} }
}; };
...@@ -143,7 +143,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference { ...@@ -143,7 +143,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
} }
}; };
void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) { void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
out->clear(); out->clear();
std::vector<framework::LoDTensor> underlying_outs; std::vector<framework::LoDTensor> underlying_outs;
reader_->ReadNext(&underlying_outs); reader_->ReadNext(&underlying_outs);
......
...@@ -23,13 +23,13 @@ namespace reader { ...@@ -23,13 +23,13 @@ namespace reader {
// 'Double buffer' means we shall maintain two batches of input data at the same // 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2. // time. So the kCacheSize shoul be at least 2.
static constexpr size_t kCacheSize = 5; static constexpr size_t kCacheSize = 3;
// There will be two bacthes out of the channel during training: // There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel // 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by // 2. the one just be received from the channel, which is also being used by
// subsequent operators. // subsequent operators.
// So the channel size should be kChacheSize - 2 // So the channel size should be kChacheSize - 2
static constexpr size_t kChannelSize = 3; // kCacheSize - 2 static constexpr size_t kChannelSize = 1; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader { class DoubleBufferReader : public framework::DecoratedReader {
public: public:
...@@ -50,12 +50,21 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -50,12 +50,21 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher(); StartPrefetcher();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
~DoubleBufferReader() { EndPrefetcher(); } ~DoubleBufferReader() { EndPrefetcher(); }
private: private:
void ShutdownImpl() override {
EndPrefetcher();
reader_->Shutdown();
}
void StartImpl() override {
reader_->Start();
StartPrefetcher();
}
void StartPrefetcher() { void StartPrefetcher() {
channel_ = new reader::BlockingQueue<size_t>(kChannelSize); channel_ = new reader::BlockingQueue<size_t>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); }); prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
...@@ -109,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { ...@@ -109,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place = platform::CUDAPlace(static_cast<int>(num)); place = platform::CUDAPlace(static_cast<int>(num));
} }
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place)); out->Reset(framework::MakeDecoratedReader<DoubleBufferReader>(
underlying_reader, place));
} }
}; };
...@@ -136,7 +146,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -136,7 +146,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
} }
}; };
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { void DoubleBufferReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
size_t cached_tensor_id; size_t cached_tensor_id;
if (channel_->Receive(&cached_tensor_id)) { if (channel_->Receive(&cached_tensor_id)) {
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
...@@ -150,12 +160,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { ...@@ -150,12 +160,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
} }
} }
void DoubleBufferReader::ReInit() {
reader_->ReInit();
EndPrefetcher();
StartPrefetcher();
}
void DoubleBufferReader::PrefetchThreadFunc() { void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts."; VLOG(5) << "A new prefetch thread starts.";
size_t cached_tensor_id = 0; size_t cached_tensor_id = 0;
......
...@@ -24,23 +24,22 @@ class MultiPassReader : public framework::DecoratedReader { ...@@ -24,23 +24,22 @@ class MultiPassReader : public framework::DecoratedReader {
MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num) MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {} : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
reader_->ReadNext(out); reader_->ReadNext(out);
if (out->empty()) { if (out->empty() && pass_count_ < pass_num_ - 1) {
reader_->Shutdown();
reader_->Start();
reader_->ReadNext(out);
++pass_count_; ++pass_count_;
if (pass_count_ < pass_num_) {
reader_->ReInit();
reader_->ReadNext(out);
}
} }
} }
void ReInit() override { private:
void StartImpl() override {
pass_count_ = 0; pass_count_ = 0;
reader_->ReInit(); reader_->Start();
} }
private:
int pass_num_; int pass_num_;
mutable int pass_count_; mutable int pass_count_;
}; };
...@@ -60,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase { ...@@ -60,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
int pass_num = Attr<int>("pass_num"); int pass_num = Attr<int>("pass_num");
out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num)); out->Reset(framework::MakeDecoratedReader<MultiPassReader>(
underlying_reader, pass_num));
} }
}; };
......
...@@ -19,22 +19,27 @@ namespace paddle { ...@@ -19,22 +19,27 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
class PyReader : public framework::ReaderBase { class PyReader : public framework::FileReader {
public: public:
explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue) { explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue)
: framework::FileReader() {
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
queue_ = queue; queue_ = queue;
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
bool success; bool success;
*out = queue_->Pop(&success); *out = queue_->Pop(&success);
if (!success) out->clear(); if (!success) out->clear();
} }
void ReInit() override {}
private: private:
void ShutdownImpl() override { /* TODO */
}
void StartImpl() override { /* TODO */
}
std::shared_ptr<LoDTensorBlockingQueue> queue_; std::shared_ptr<LoDTensorBlockingQueue> queue_;
}; };
...@@ -51,14 +56,14 @@ class CreatePyReaderOp : public framework::OperatorBase { ...@@ -51,14 +56,14 @@ class CreatePyReaderOp : public framework::OperatorBase {
const std::string& queue_name = Input("blocking_queue"); const std::string& queue_name = Input("blocking_queue");
auto* queue_holder_var = scope.FindVar(queue_name); auto* queue_holder_var = scope.FindVar(queue_name);
PADDLE_ENFORCE( PADDLE_ENFORCE_NOT_NULL(
queue_holder_var != nullptr, queue_holder_var,
"No LoDTensorBlockingQueueHolder variable with name %s found", "No LoDTensorBlockingQueueHolder variable with name %s found",
queue_name); queue_name);
auto* queue_holder = auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>(); queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
out->Reset(new PyReader(queue_holder->GetQueue())); out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue()));
} }
}; };
......
...@@ -19,11 +19,11 @@ namespace operators { ...@@ -19,11 +19,11 @@ namespace operators {
namespace reader { namespace reader {
template <typename T> template <typename T>
class RandomDataGenerator : public framework::ReaderBase { class RandomDataGenerator : public framework::FileReader {
public: public:
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float low, RandomDataGenerator(const std::vector<framework::DDim>& shapes, float low,
float high) float high)
: framework::ReaderBase(), low_(low), high_(high), shapes_(shapes) { : framework::FileReader(), low_(low), high_(high), shapes_(shapes) {
PADDLE_ENFORCE_LE(low, high, PADDLE_ENFORCE_LE(low, high,
"'low' shouldn't be greater than 'high'.(%f vs %f)", low, "'low' shouldn't be greater than 'high'.(%f vs %f)", low,
high); high);
...@@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase { ...@@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase {
dist_ = std::uniform_real_distribution<float>(low_, high_); dist_ = std::uniform_real_distribution<float>(low_, high_);
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
out->clear(); out->clear();
out->reserve(shapes_.size()); out->reserve(shapes_.size());
for (const framework::DDim& shape : shapes_) { for (const framework::DDim& shape : shapes_) {
...@@ -51,8 +51,6 @@ class RandomDataGenerator : public framework::ReaderBase { ...@@ -51,8 +51,6 @@ class RandomDataGenerator : public framework::ReaderBase {
} }
} }
void ReInit() override { return; }
private: private:
float low_; float low_;
float high_; float high_;
...@@ -79,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase { ...@@ -79,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks); std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("low"), out->Reset(std::make_shared<RandomDataGenerator<T>>(
Attr<float>("high"))); shapes, Attr<float>("low"), Attr<float>("high")));
} }
}; };
......
...@@ -21,10 +21,8 @@ namespace reader { ...@@ -21,10 +21,8 @@ namespace reader {
template <bool ThreadSafe> template <bool ThreadSafe>
class RecordIOFileReader : public framework::FileReader { class RecordIOFileReader : public framework::FileReader {
public: public:
explicit RecordIOFileReader(const std::string& filename, explicit RecordIOFileReader(const std::string& filename)
const std::vector<framework::DDim>& dims) : scanner_(filename),
: FileReader(dims),
scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get( dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) { platform::CPUPlace())) {
if (ThreadSafe) { if (ThreadSafe) {
...@@ -33,8 +31,6 @@ class RecordIOFileReader : public framework::FileReader { ...@@ -33,8 +31,6 @@ class RecordIOFileReader : public framework::FileReader {
LOG(INFO) << "Creating file reader" << filename; LOG(INFO) << "Creating file reader" << filename;
} }
void ReInit() override { scanner_.Reset(); }
protected: protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
if (ThreadSafe) { if (ThreadSafe) {
...@@ -45,6 +41,8 @@ class RecordIOFileReader : public framework::FileReader { ...@@ -45,6 +41,8 @@ class RecordIOFileReader : public framework::FileReader {
} }
} }
void StartImpl() override { scanner_.Reset(); }
private: private:
std::unique_ptr<std::mutex> mutex_; std::unique_ptr<std::mutex> mutex_;
recordio::Scanner scanner_; recordio::Scanner scanner_;
...@@ -58,20 +56,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { ...@@ -58,20 +56,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::string filename = Attr<std::string>("filename"); std::string filename = Attr<std::string>("filename");
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader<true>( out->Reset(std::make_shared<RecordIOFileReader<true>>(filename));
filename, RestoreShapes(shape_concat, ranks)));
} }
}; };
......
...@@ -34,7 +34,7 @@ class ShuffleReader : public framework::DecoratedReader { ...@@ -34,7 +34,7 @@ class ShuffleReader : public framework::DecoratedReader {
ReloadBuffer(); ReloadBuffer();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
out->clear(); out->clear();
if (iteration_pos_ >= buffer_.size()) { if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer"; VLOG(10) << "Resetting shuffle buffer";
...@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader { ...@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader {
} }
private: private:
void ShutdownImpl() override {
buffer_.clear();
iteration_pos_ = 0;
reader_->Shutdown();
}
void StartImpl() override {
reader_->Start();
ReloadBuffer();
}
void ReloadBuffer() { void ReloadBuffer() {
buffer_.clear(); buffer_.clear();
buffer_.reserve(buffer_size_); buffer_.reserve(buffer_size_);
...@@ -86,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase { ...@@ -86,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(framework::MakeDecoratedReader<ShuffleReader>(
new ShuffleReader(underlying_reader.Get(), underlying_reader, static_cast<size_t>(Attr<int>("buffer_size"))));
static_cast<size_t>(Attr<int>("buffer_size"))));
} }
}; };
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class ThreadedReader : public framework::DecoratedReader {
public:
explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader)
: DecoratedReader(reader) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_);
reader_->ReadNext(out);
}
void ReInit() override { reader_->ReInit(); }
private:
std::mutex mutex_;
};
class CreateThreadedReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto* out = detail::Ref(scope.FindVar(Output("Out")))
.GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(new ThreadedReader(underlying_reader.Get()));
}
};
class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
protected:
void Apply() override {
AddComment(R"DOC(
CreateThreadedReader Operator
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace reader = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_threaded_reader,
reader::CreateThreadedReaderOp,
reader::CreateThreadedReaderOpMaker);
...@@ -23,24 +23,26 @@ namespace reader { ...@@ -23,24 +23,26 @@ namespace reader {
class MultiFileReader : public framework::ReaderBase { class MultiFileReader : public framework::ReaderBase {
public: public:
MultiFileReader(const std::vector<std::string>& file_names, MultiFileReader(const std::vector<std::string>& file_names, size_t thread_num,
const std::vector<framework::DDim>& dims, size_t thread_num,
size_t buffer_size) size_t buffer_size)
: buffer_size_(buffer_size) { : buffer_size_(buffer_size) {
readers_.reserve(file_names.size()); readers_.reserve(file_names.size());
for (const std::string& f_name : file_names) { for (const std::string& f_name : file_names) {
readers_.emplace_back(CreateReaderByFileName(f_name, dims)); readers_.emplace_back(CreateReaderByFileName(f_name));
} }
prefetchers_.resize(thread_num); prefetchers_.resize(thread_num);
StartNewScheduler(); StartNewScheduler();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
~MultiFileReader() { EndScheduler(); } ~MultiFileReader() { EndScheduler(); }
private: private:
void ShutdownImpl() override { EndScheduler(); }
void StartImpl() override { StartNewScheduler(); }
void StartNewScheduler(); void StartNewScheduler();
void EndScheduler(); void EndScheduler();
void ScheduleThreadFunc(); void ScheduleThreadFunc();
...@@ -55,17 +57,12 @@ class MultiFileReader : public framework::ReaderBase { ...@@ -55,17 +57,12 @@ class MultiFileReader : public framework::ReaderBase {
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_; reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
}; };
void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) { void MultiFileReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
if (!buffer_->Receive(out)) { if (!buffer_->Receive(out)) {
out->clear(); out->clear();
} }
} }
void MultiFileReader::ReInit() {
EndScheduler();
StartNewScheduler();
}
void MultiFileReader::StartNewScheduler() { void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size(); size_t thread_num = prefetchers_.size();
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size()); waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
...@@ -120,7 +117,7 @@ void MultiFileReader::ScheduleThreadFunc() { ...@@ -120,7 +117,7 @@ void MultiFileReader::ScheduleThreadFunc() {
} }
} }
} }
// If users invoke ReInit() when scheduler is running, it will close the // If users invoke Shutdown() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler // 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends. // to release their resource. So a check is needed before scheduler ends.
for (auto& p : prefetchers_) { for (auto& p : prefetchers_) {
...@@ -138,7 +135,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) { ...@@ -138,7 +135,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
std::vector<framework::LoDTensor> ins; std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins); reader->ReadNext(&ins);
if (ins.empty()) { if (ins.empty()) {
reader->ReInit(); reader->Shutdown();
reader->Start();
break; break;
} }
try { try {
...@@ -180,9 +178,8 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -180,9 +178,8 @@ class OpenFilesOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultiFileReader(file_names, out->Reset(
RestoreShapes(shape_concat, ranks), std::make_shared<MultiFileReader>(file_names, thread_num, buffer_size));
thread_num, buffer_size));
} }
}; };
......
...@@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() { ...@@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
} }
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims) { const std::string& file_name) {
size_t separator_pos = file_name.find_last_of(kFileFormatSeparator); size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
PADDLE_ENFORCE_NE(separator_pos, std::string::npos, PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
"File name illegal! A legal file name should be like: " "File name illegal! A legal file name should be like: "
...@@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( ...@@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
auto itor = FileReaderRegistry().find(filetype); auto itor = FileReaderRegistry().find(filetype);
PADDLE_ENFORCE(itor != FileReaderRegistry().end(), PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
"No file reader registered for '%s' format.", filetype); "No file reader registered for '%s' format.", filetype);
framework::ReaderBase* reader = (itor->second)(file_name, dims); framework::ReaderBase* reader = (itor->second)(file_name);
return std::unique_ptr<framework::ReaderBase>(reader); return std::unique_ptr<framework::ReaderBase>(reader);
} }
......
...@@ -25,22 +25,21 @@ namespace reader { ...@@ -25,22 +25,21 @@ namespace reader {
static constexpr char kFileFormatSeparator[] = "."; static constexpr char kFileFormatSeparator[] = ".";
using FileReaderCreator = std::function<framework::ReaderBase*( using FileReaderCreator =
const std::string&, const std::vector<framework::DDim>&)>; std::function<framework::ReaderBase*(const std::string&)>;
std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry(); std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader> template <typename Reader>
int RegisterFileReader(const std::string& filetype) { int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = []( FileReaderRegistry()[filetype] = [](const std::string& fn) {
const std::string& fn, const std::vector<framework::DDim>& dims) { return new Reader(fn);
return new Reader(fn, dims);
}; };
return 0; return 0;
} }
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims); const std::string& file_name);
extern std::vector<framework::DDim> RestoreShapes( extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks); const std::vector<int>& shape_concat, const std::vector<int>& ranks);
......
...@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference); py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "") py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("reset", &framework::ReaderHolder::ReInit); .def("reset", &framework::ReaderHolder::ResetAll);
using LoDTensorBlockingQueue = using LoDTensorBlockingQueue =
::paddle::operators::reader::LoDTensorBlockingQueue; ::paddle::operators::reader::LoDTensorBlockingQueue;
......
...@@ -375,9 +375,6 @@ def open_recordio_file(filename, ...@@ -375,9 +375,6 @@ def open_recordio_file(filename,
if pass_num > 1: if pass_num > 1:
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num) main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
if for_parallel:
main_prog_var = parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
...@@ -529,9 +526,6 @@ def open_files(filenames, ...@@ -529,9 +526,6 @@ def open_files(filenames,
main_prog_reader = multi_pass( main_prog_reader = multi_pass(
reader=main_prog_reader, pass_num=pass_num) reader=main_prog_reader, pass_num=pass_num)
if for_parallel:
main_prog_reader = parallel(reader=main_prog_reader)
return monkey_patch_reader_methods(main_prog_reader) return monkey_patch_reader_methods(main_prog_reader)
...@@ -647,11 +641,6 @@ def multi_pass(reader, pass_num): ...@@ -647,11 +641,6 @@ def multi_pass(reader, pass_num):
'create_multi_pass_reader', reader, {'pass_num': int(pass_num)}) 'create_multi_pass_reader', reader, {'pass_num': int(pass_num)})
def parallel(reader):
return __create_shared_decorated_reader__('create_threaded_reader', reader,
{})
def read_file(reader): def read_file(reader):
""" """
Execute the given reader and get data via it. Execute the given reader and get data via it.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册