提交 a84b8150 编写于 作者: F fengjiayi

Remove Readers' HasNext()

上级 53aea5e1
...@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
...@@ -22,11 +27,6 @@ limitations under the License. */ ...@@ -22,11 +27,6 @@ limitations under the License. */
#include "paddle/fluid/recordio/scanner.h" #include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h" #include "paddle/fluid/recordio/writer.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -294,7 +294,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, ...@@ -294,7 +294,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
TensorFromStream(is, static_cast<Tensor *>(tensor), dev_ctx); TensorFromStream(is, static_cast<Tensor *>(tensor), dev_ctx);
} }
void WriteToRecordIO(recordio::Writer &writer, void WriteToRecordIO(recordio::Writer *writer,
const std::vector<LoDTensor> &tensor, const std::vector<LoDTensor> &tensor,
const platform::DeviceContext &dev_ctx) { const platform::DeviceContext &dev_ctx) {
std::stringstream buffer; std::stringstream buffer;
...@@ -303,18 +303,20 @@ void WriteToRecordIO(recordio::Writer &writer, ...@@ -303,18 +303,20 @@ void WriteToRecordIO(recordio::Writer &writer,
for (auto &each : tensor) { for (auto &each : tensor) {
SerializeToStream(buffer, each, dev_ctx); SerializeToStream(buffer, each, dev_ctx);
} }
writer.Write(buffer.str()); writer->Write(buffer.str());
} }
std::vector<LoDTensor> ReadFromRecordIO( std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner &scanner, const platform::DeviceContext &dev_ctx) { recordio::Scanner *scanner, const platform::DeviceContext &dev_ctx) {
std::istringstream sin(scanner.Next());
uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
std::vector<LoDTensor> result; std::vector<LoDTensor> result;
result.resize(sz); if (scanner->HasNext()) {
for (uint32_t i = 0; i < sz; ++i) { std::istringstream sin(scanner->Next());
DeserializeFromStream(sin, &result[i], dev_ctx); uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
result.resize(sz);
for (uint32_t i = 0; i < sz; ++i) {
DeserializeFromStream(sin, &result[i], dev_ctx);
}
} }
return result; return result;
} }
......
...@@ -15,6 +15,9 @@ limitations under the License. */ ...@@ -15,6 +15,9 @@ limitations under the License. */
#pragma once #pragma once
#include <memory> #include <memory>
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
...@@ -216,12 +219,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor, ...@@ -216,12 +219,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
void DeserializeFromStream(std::istream& is, LoDTensor* tensor, void DeserializeFromStream(std::istream& is, LoDTensor* tensor,
const platform::DeviceContext& dev_ctx); const platform::DeviceContext& dev_ctx);
extern void WriteToRecordIO(recordio::Writer& writer, extern void WriteToRecordIO(recordio::Writer* writer,
const std::vector<LoDTensor>& tensor, const std::vector<LoDTensor>& tensor,
const platform::DeviceContext& dev_ctx); const platform::DeviceContext& dev_ctx);
extern std::vector<LoDTensor> ReadFromRecordIO( extern std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner& scanner, const platform::DeviceContext& dev_ctx); recordio::Scanner* scanner, const platform::DeviceContext& dev_ctx);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,17 +12,17 @@ ...@@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -240,8 +240,8 @@ TEST(LoDTensor, RecordIO) { ...@@ -240,8 +240,8 @@ TEST(LoDTensor, RecordIO) {
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); *platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
{ {
recordio::Writer writer(stream, recordio::Compressor::kSnappy); recordio::Writer writer(stream, recordio::Compressor::kSnappy);
WriteToRecordIO(writer, {tensor, tensor}, ctx); WriteToRecordIO(&writer, {tensor, tensor}, ctx);
WriteToRecordIO(writer, {tensor, tensor}, ctx); WriteToRecordIO(&writer, {tensor, tensor}, ctx);
writer.Flush(); writer.Flush();
} }
...@@ -254,11 +254,11 @@ TEST(LoDTensor, RecordIO) { ...@@ -254,11 +254,11 @@ TEST(LoDTensor, RecordIO) {
{ {
std::unique_ptr<std::istream> stream_ptr(stream); std::unique_ptr<std::istream> stream_ptr(stream);
recordio::Scanner scanner(std::move(stream_ptr)); recordio::Scanner scanner(std::move(stream_ptr));
auto tensors = ReadFromRecordIO(scanner, ctx); auto tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2); ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]); assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]); assert_tensor_ok(tensors[1]);
tensors = ReadFromRecordIO(scanner, ctx); tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2); ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]); assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]); assert_tensor_ok(tensors[1]);
......
...@@ -14,14 +14,13 @@ ...@@ -14,14 +14,13 @@
#pragma once #pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include <memory>
#include <thread>
#include <vector>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -31,8 +30,6 @@ class ReaderBase { ...@@ -31,8 +30,6 @@ class ReaderBase {
virtual void ReInit() = 0; virtual void ReInit() = 0;
virtual bool HasNext() const = 0;
virtual ~ReaderBase(); virtual ~ReaderBase();
}; };
...@@ -44,8 +41,6 @@ class DecoratedReader : public ReaderBase { ...@@ -44,8 +41,6 @@ class DecoratedReader : public ReaderBase {
void ReInit() override { reader_->ReInit(); } void ReInit() override { reader_->ReInit(); }
bool HasNext() const override { return reader_->HasNext(); }
protected: protected:
ReaderBase* reader_; ReaderBase* reader_;
}; };
...@@ -80,8 +75,6 @@ class ReaderHolder { ...@@ -80,8 +75,6 @@ class ReaderHolder {
reader_->ReInit(); reader_->ReInit();
} }
bool HasNext() const { return reader_->HasNext(); }
private: private:
std::unique_ptr<ReaderBase> reader_; std::unique_ptr<ReaderBase> reader_;
}; };
......
...@@ -63,13 +63,14 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -63,13 +63,14 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher(); StartPrefetcher();
} }
bool HasNext() const override;
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override; void ReInit() override;
~DoubleBufferReader() { EndPrefetcher(); } ~DoubleBufferReader() { EndPrefetcher(); }
private: private:
bool HasNext() const;
void StartPrefetcher() { void StartPrefetcher() {
channel_ = framework::MakeChannel<Item>(kChannelSize); channel_ = framework::MakeChannel<Item>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); }); prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
...@@ -149,22 +150,15 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -149,22 +150,15 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
} }
}; };
bool DoubleBufferReader::HasNext() const {
while (!channel_->IsClosed() && !channel_->CanReceive()) {
}
return channel_->CanReceive();
}
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) { out->clear();
PADDLE_THROW("There is no next data!"); if (HasNext()) {
} Item batch;
channel_->Receive(&batch);
Item batch; *out = batch.payloads_;
channel_->Receive(&batch); if (batch.ctx_) {
*out = batch.payloads_; batch.ctx_->Wait();
if (batch.ctx_) { }
batch.ctx_->Wait();
} }
} }
...@@ -174,16 +168,26 @@ void DoubleBufferReader::ReInit() { ...@@ -174,16 +168,26 @@ void DoubleBufferReader::ReInit() {
StartPrefetcher(); StartPrefetcher();
} }
bool DoubleBufferReader::HasNext() const {
while (!channel_->IsClosed() && !channel_->CanReceive()) {
}
return channel_->CanReceive();
}
void DoubleBufferReader::PrefetchThreadFunc() { void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts."; VLOG(5) << "A new prefetch thread starts.";
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize); std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize); std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
size_t cached_tensor_id = 0; size_t cached_tensor_id = 0;
while (reader_->HasNext()) { while (true) {
Item batch; Item batch;
auto& cpu_batch = cpu_tensor_cache[cached_tensor_id]; auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
reader_->ReadNext(&cpu_batch); reader_->ReadNext(&cpu_batch);
if (cpu_batch.empty()) {
// The underlying reader has no more data.
break;
}
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
auto& gpu_batch = gpu_tensor_cache[cached_tensor_id]; auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
auto* gpu_ctx = ctxs_[cached_tensor_id].get(); auto* gpu_ctx = ctxs_[cached_tensor_id].get();
......
...@@ -25,22 +25,12 @@ class MultiPassReader : public framework::DecoratedReader { ...@@ -25,22 +25,12 @@ class MultiPassReader : public framework::DecoratedReader {
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {} : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
reader_->ReadNext(out); reader_->ReadNext(out);
} if (out->empty()) {
bool HasNext() const override {
if (reader_->HasNext()) {
return true;
} else {
++pass_count_; ++pass_count_;
if (pass_count_ >= pass_num_) { if (pass_count_ < pass_num_) {
return false;
} else {
reader_->ReInit(); reader_->ReInit();
return true; reader_->ReadNext(out);
} }
} }
} }
......
...@@ -52,8 +52,6 @@ class RandomDataGenerator : public framework::ReaderBase { ...@@ -52,8 +52,6 @@ class RandomDataGenerator : public framework::ReaderBase {
void ReInit() override { return; } void ReInit() override { return; }
bool HasNext() const override { return true; }
private: private:
float min_; float min_;
float max_; float max_;
...@@ -74,7 +72,7 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase { ...@@ -74,7 +72,7 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
const auto& ranks = Attr<std::vector<int>>("ranks"); const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()), static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the " "The accumulate of all ranks should be equal to the "
"shape concat's length."); "shape concat's length.");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks); std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
......
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <mutex>
#include <thread>
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/recordio/scanner.h" #include "paddle/fluid/recordio/scanner.h"
...@@ -35,17 +33,15 @@ class RecordIOFileReader : public framework::FileReader { ...@@ -35,17 +33,15 @@ class RecordIOFileReader : public framework::FileReader {
LOG(INFO) << "Creating file reader" << filename; LOG(INFO) << "Creating file reader" << filename;
} }
bool HasNext() const override { return scanner_.HasNext(); }
void ReInit() override { scanner_.Reset(); } void ReInit() override { scanner_.Reset(); }
protected: protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
if (ThreadSafe) { if (ThreadSafe) {
std::lock_guard<std::mutex> guard(*mutex_); std::lock_guard<std::mutex> guard(*mutex_);
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_); *out = framework::ReadFromRecordIO(&scanner_, dev_ctx_);
} else { } else {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_); *out = framework::ReadFromRecordIO(&scanner_, dev_ctx_);
} }
} }
...@@ -66,7 +62,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { ...@@ -66,7 +62,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
const auto& ranks = Attr<std::vector<int>>("ranks"); const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()), static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the " "The accumulate of all ranks should be equal to the "
"shape concat's length."); "shape concat's length.");
std::string filename = Attr<std::string>("filename"); std::string filename = Attr<std::string>("filename");
......
...@@ -30,35 +30,33 @@ class ShuffleReader : public framework::DecoratedReader { ...@@ -30,35 +30,33 @@ class ShuffleReader : public framework::DecoratedReader {
std::random_device device; std::random_device device;
seed_ = device(); seed_ = device();
} }
ReadIntoBuffers(); ReloadBuffer();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) { out->clear();
PADDLE_THROW("There is no next data!");
}
if (iteration_pos_ >= buffer_.size()) { if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer"; VLOG(10) << "Resetting shuffle buffer";
ReadIntoBuffers(); ReloadBuffer();
if (buffer_.empty()) {
return;
}
} }
*out = buffer_[iteration_pos_++]; *out = buffer_[iteration_pos_++];
} }
bool HasNext() const override {
return iteration_pos_ < buffer_.size() || reader_->HasNext();
}
private: private:
void ReadIntoBuffers() { void ReloadBuffer() {
buffer_.clear(); buffer_.clear();
buffer_.reserve(buffer_size_); buffer_.reserve(buffer_size_);
iteration_pos_ = 0; iteration_pos_ = 0;
for (size_t i = 0; i < buffer_size_; ++i) { for (size_t i = 0; i < buffer_size_; ++i) {
if (!reader_->HasNext()) { std::vector<framework::LoDTensor> ins;
reader_->ReadNext(&ins);
if (ins.empty()) {
break; break;
} }
buffer_.emplace_back(); buffer_.emplace_back(ins);
reader_->ReadNext(&buffer_.back());
} }
std::mt19937 g(seed_); std::mt19937 g(seed_);
std::shuffle(buffer_.begin(), buffer_.end(), g); std::shuffle(buffer_.begin(), buffer_.end(), g);
......
...@@ -21,67 +21,27 @@ namespace reader { ...@@ -21,67 +21,27 @@ namespace reader {
class ThreadedReader : public framework::DecoratedReader { class ThreadedReader : public framework::DecoratedReader {
public: public:
ThreadedReader(ReaderBase* reader, bool unsafe_mode) ThreadedReader(ReaderBase* reader, bool safe_mode)
: DecoratedReader(reader), unsafe_mode_(unsafe_mode) {} : DecoratedReader(reader), safe_mode_(safe_mode) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (!unsafe_mode_) { reader_->ReadNext(out);
if (!reader_->HasNext()) {
PADDLE_THROW("There is no next data!");
}
reader_->ReadNext(out);
} else {
auto& thread_buffer = thread_buffers_[std::this_thread::get_id()];
if (thread_buffer.empty()) {
PADDLE_THROW(
"thread_buffer is empty! HasNext() must be invoked before "
"ReadNext() in the same thread.");
}
*out = thread_buffer;
thread_buffer.clear();
}
}
bool HasNext() const override {
if (!unsafe_mode_) {
PADDLE_THROW(
"ThreadedReader::HasNext() is disabled when 'unsafe_mode' is false.");
}
std::thread::id thread_id = std::this_thread::get_id();
std::lock_guard<std::mutex> lock(mutex_);
auto& thread_buffer = thread_buffers_[thread_id];
if (thread_buffer.empty() && reader_->HasNext()) {
reader_->ReadNext(&thread_buffer);
}
return !thread_buffer.empty();
} }
void ReInit() override { void ReInit() override {
if (!unsafe_mode_) { if (safe_mode_) {
PADDLE_THROW( PADDLE_THROW(
"ThreadedReader::ReInit() is disabled when 'unsafe_mode' is false."); "ThreadedReader::ReInit() is disabled when 'safe_mode' is true.");
} }
VLOG(5) << "ThreadedReader::ReInit() is invoked! It might be buggy in " VLOG(5) << "ThreadedReader::ReInit() is invoked! It might be buggy in "
"multi-thread environment."; "multi-thread environment.";
reader_->ReInit(); reader_->ReInit();
} }
~ThreadedReader() {
for (auto& p : thread_buffers_) {
if (!p.second.empty()) {
PADDLE_THROW(
"Find an unused data batch in ThreadedReader! Maybe one thread "
"invokes 'HasNext()' without subsequent 'ReadNext()'.");
}
}
}
private: private:
bool unsafe_mode_; bool safe_mode_;
mutable std::mutex mutex_; std::mutex mutex_;
mutable std::unordered_map<std::thread::id, std::vector<framework::LoDTensor>>
thread_buffers_;
}; };
class CreateThreadedReaderOp : public framework::OperatorBase { class CreateThreadedReaderOp : public framework::OperatorBase {
...@@ -98,8 +58,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase { ...@@ -98,8 +58,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
bool unsafe_mode = Attr<bool>("unsafe_mode"); bool safe_mode = Attr<bool>("safe_mode");
out->Reset(new ThreadedReader(underlying_reader.Get(), unsafe_mode)); out->Reset(new ThreadedReader(underlying_reader.Get(), safe_mode));
} }
}; };
...@@ -107,10 +67,9 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -107,10 +67,9 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
public: public:
CreateThreadedReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) CreateThreadedReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) { : DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<bool>("unsafe_mode", AddAttr<bool>("safe_mode",
"When 'unsafe_mode' is false, invoking 'HasNext()' or " "When 'safe_mode' is true, 'ReInit()' is disabled to avoid "
"'ReInit()' is not allowed to avoid unexpected bugs in " "unexpected bugs in multi-thread environment.")
"multi-thread environment.")
.SetDefault(true); .SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
CreateThreadedReader Operator CreateThreadedReader Operator
...@@ -118,13 +77,9 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -118,13 +77,9 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
This operator creates a threaded reader. A threaded reader's This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same 'ReadNext()' can be invoked by several threads at the same
time. time.
When the attribute 'unsafe_mode' is false, the threaded reader's When the attribute 'safe_mode' is true, the threaded reader's
'HasNext()' and 'ReInit()' will be disabled to avoid unexpected 'ReInit()' is disabled to avoid unexpected bugs in multi-thread
bugs in multi-thread environment. If you really need them, you environment.
can enable them by setting 'unsafe_mode' true. In this case,
'HasNext()' returning true only guarantees the safety of
invoking 'ReadNext()' in the same thread. Each thread must
invoke 'HasNext()' and 'ReadNext()' in pairs.
)DOC"); )DOC");
} }
}; };
......
...@@ -30,12 +30,12 @@ class MultiFileReader : public framework::ReaderBase { ...@@ -30,12 +30,12 @@ class MultiFileReader : public framework::ReaderBase {
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNext(std::vector<framework::LoDTensor>* out) override;
bool HasNext() const override;
void ReInit() override; void ReInit() override;
~MultiFileReader() { EndScheduler(); } ~MultiFileReader() { EndScheduler(); }
private: private:
bool HasNext();
void StartNewScheduler(); void StartNewScheduler();
void EndScheduler(); void EndScheduler();
void ScheduleThreadFunc(); void ScheduleThreadFunc();
...@@ -52,16 +52,10 @@ class MultiFileReader : public framework::ReaderBase { ...@@ -52,16 +52,10 @@ class MultiFileReader : public framework::ReaderBase {
}; };
void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) { void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) { out->clear();
PADDLE_THROW("There is no next data!"); if (HasNext()) {
buffer_->Receive(out);
} }
buffer_->Receive(out);
}
bool MultiFileReader::HasNext() const {
while (!buffer_->IsClosed() && !buffer_->CanReceive()) {
}
return buffer_->CanReceive();
} }
void MultiFileReader::ReInit() { void MultiFileReader::ReInit() {
...@@ -69,6 +63,12 @@ void MultiFileReader::ReInit() { ...@@ -69,6 +63,12 @@ void MultiFileReader::ReInit() {
StartNewScheduler(); StartNewScheduler();
} }
bool MultiFileReader::HasNext() {
while (!buffer_->IsClosed() && !buffer_->CanReceive()) {
}
return buffer_->CanReceive();
}
void MultiFileReader::StartNewScheduler() { void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size(); size_t thread_num = prefetchers_.size();
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size()); waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
...@@ -140,9 +140,12 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name, ...@@ -140,9 +140,12 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader = std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_); CreateReaderByFileName(file_name, dims_);
while (reader->HasNext()) { while (true) {
std::vector<framework::LoDTensor> ins; std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins); reader->ReadNext(&ins);
if (ins.empty()) {
break;
}
try { try {
buffer_->Send(&ins); buffer_->Send(&ins);
} catch (paddle::platform::EnforceNotMet e) { } catch (paddle::platform::EnforceNotMet e) {
......
...@@ -252,7 +252,6 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -252,7 +252,6 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference); py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "") py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("has_next", &framework::ReaderHolder::HasNext)
.def("reset", &framework::ReaderHolder::ReInit); .def("reset", &framework::ReaderHolder::ReInit);
py::class_<Scope>(m, "Scope", "") py::class_<Scope>(m, "Scope", "")
......
...@@ -39,7 +39,7 @@ class RecordIOWriter { ...@@ -39,7 +39,7 @@ class RecordIOWriter {
void CompleteAppendTensor() { void CompleteAppendTensor() {
auto& ctx = auto& ctx =
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); *platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
framework::WriteToRecordIO(writer_, tensors_, ctx); framework::WriteToRecordIO(&writer_, tensors_, ctx);
tensors_.clear(); tensors_.clear();
} }
......
...@@ -236,13 +236,9 @@ def monkey_patch_reader_methods(reader): ...@@ -236,13 +236,9 @@ def monkey_patch_reader_methods(reader):
var = scope.find_var(reader.name) var = scope.find_var(reader.name)
return var.get_reader() return var.get_reader()
def eof():
return not __get_reader__().has_next()
def reset(): def reset():
return __get_reader__().reset() return __get_reader__().reset()
reader.eof = eof
reader.reset = reset reader.reset = reset
reader.stop_gradient = True reader.stop_gradient = True
reader.persistable = True reader.persistable = True
...@@ -299,8 +295,7 @@ def open_recordio_file(filename, ...@@ -299,8 +295,7 @@ def open_recordio_file(filename,
shapes(list): List of tuples which declaring data shapes. shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level. lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type. dtypes(list): List of strs which declaring data type.
pass_num(int): Number of passes to run. After completing the pass_num(int): Number of passes to run.
given number of passes, 'has_next()' will return False.
for_parallel(Bool): Set it as True if you are going to run for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel. subsequent operators in parallel.
...@@ -377,8 +372,7 @@ def open_files(filenames, ...@@ -377,8 +372,7 @@ def open_files(filenames,
dtypes(list): List of strs which declaring data type. dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number. thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int): The size of prefetch buffer. buffer_size(int): The size of prefetch buffer.
pass_num(int): Number of passes to run. After completing the pass_num(int): Number of passes to run.
given number of passes, 'has_next()' will return False.
for_parallel(Bool): Set it as True if you are going to run for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel. subsequent operators in parallel.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册