未验证 提交 48f213e5 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #8991 from reyoung/feature/shuffle_reader

Feature/shuffle reader
...@@ -445,15 +445,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -445,15 +445,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
std::vector<DDim> GetRepeatedDims(const std::string& name) const override { std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
Variable* var = scope_.FindVar(name); PADDLE_THROW("Only compile time support this method");
if (var->IsType<ReaderHolder>()) {
return var->Get<ReaderHolder>().shapes();
} else {
PADDLE_THROW(
"Only ReaderHolder support 'GetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
} }
void SetDim(const std::string& name, const DDim& dim) override { void SetDim(const std::string& name, const DDim& dim) override {
...@@ -470,15 +462,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -470,15 +462,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
void SetRepeatedDims(const std::string& name, void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override { const std::vector<DDim>& dims) override {
Variable* var = scope_.FindVar(name); PADDLE_THROW("Only compile time support this method");
if (var->IsType<ReaderHolder>()) {
var->GetMutable<ReaderHolder>()->set_shapes(dims);
} else {
PADDLE_THROW(
"Only ReaderHolder support 'SetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
} }
proto::VarType::Type GetVarType(const std::string& name) const override { proto::VarType::Type GetVarType(const std::string& name) const override {
......
...@@ -16,14 +16,22 @@ ...@@ -16,14 +16,22 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
ReaderBase::~ReaderBase() {}
DDim ReaderBase::shape(size_t idx) const { FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
PADDLE_ENFORCE_LT(
idx, shapes_.size(), void FileReader::ReadNext(std::vector<LoDTensor> *out) {
"Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx, ReadNextImpl(out);
shapes_.size()); PADDLE_ENFORCE_EQ(out->size(), dims_.size());
return shapes_[idx]; for (size_t i = 0; i < dims_.size(); ++i) {
} auto &actual = out->at(i).dims();
auto &expect = dims_[i];
PADDLE_ENFORCE_EQ(actual.size(), expect.size());
for (int j = 0; j < actual.size(); ++j) {
PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
}
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -16,40 +16,29 @@ ...@@ -16,40 +16,29 @@
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
#include <memory>
#include <thread>
#include <vector>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class ReaderBase { class ReaderBase {
public: public:
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
PADDLE_ENFORCE(!shapes_.empty());
}
virtual void ReadNext(std::vector<LoDTensor>* out) = 0; virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
virtual void ReInit() = 0; virtual void ReInit() = 0;
DDim shape(size_t idx) const;
std::vector<DDim> shapes() const { return shapes_; }
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
virtual bool HasNext() const = 0; virtual bool HasNext() const = 0;
virtual ~ReaderBase() {} virtual ~ReaderBase();
protected:
std::vector<DDim> shapes_;
};
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {}
}; };
class DecoratedReader : public ReaderBase { class DecoratedReader : public ReaderBase {
public: public:
explicit DecoratedReader(ReaderBase* reader) explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) {
: ReaderBase(reader->shapes()), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
} }
...@@ -61,6 +50,19 @@ class DecoratedReader : public ReaderBase { ...@@ -61,6 +50,19 @@ class DecoratedReader : public ReaderBase {
ReaderBase* reader_; ReaderBase* reader_;
}; };
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& dims);
void ReadNext(std::vector<LoDTensor>* out) override;
protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
private:
std::vector<DDim> dims_;
};
// The ReaderHolder is used as reader' unified wrapper, // The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables. // making it easier to access different type reader in Variables.
class ReaderHolder { class ReaderHolder {
...@@ -78,19 +80,6 @@ class ReaderHolder { ...@@ -78,19 +80,6 @@ class ReaderHolder {
reader_->ReInit(); reader_->ReInit();
} }
DDim shape(size_t idx) const {
PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shape(idx);
}
std::vector<DDim> shapes() const {
PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shapes();
}
void set_shapes(const std::vector<DDim>& shapes) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->set_shapes(shapes);
}
bool HasNext() const { return reader_->HasNext(); } bool HasNext() const { return reader_->HasNext(); }
private: private:
......
...@@ -24,11 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2; ...@@ -24,11 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2;
class DoubleBufferReader : public framework::DecoratedReader { class DoubleBufferReader : public framework::DecoratedReader {
public: public:
explicit DoubleBufferReader(ReaderBase* reader) struct Item {
: DecoratedReader(reader), Item() : ctx_(nullptr) {}
buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>(
kDoubleBufferSize)) { std::vector<framework::LoDTensor> payloads_;
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this); platform::DeviceContext* ctx_;
};
explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) {
for (size_t i = 0; i < kDoubleBufferSize; ++i) {
if (platform::is_gpu_place(place_)) {
#ifdef PADDLE_WITH_CUDA
ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_)));
#endif
}
}
start_thread();
}
void start_thread() {
buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
std::thread prefetch([this] { PrefetchThreadFunc(); });
prefetch.detach(); prefetch.detach();
} }
...@@ -42,7 +62,10 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -42,7 +62,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
private: private:
void PrefetchThreadFunc(); void PrefetchThreadFunc();
framework::Channel<std::vector<framework::LoDTensor>>* buffer_; framework::Channel<Item>* buffer_;
platform::Place place_;
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
mutable Item local_buffer_;
}; };
class CreateDoubleBufferReaderOp : public framework::OperatorBase { class CreateDoubleBufferReaderOp : public framework::OperatorBase {
...@@ -56,7 +79,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { ...@@ -56,7 +79,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new DoubleBufferReader(underlying_reader.Get()));
auto place_str = Attr<std::string>("place");
platform::Place place;
if (place_str == "CPU") {
place = platform::CPUPlace();
} else {
std::istringstream sin(place_str);
sin.seekg(std::string("CUDA:").size(), std::ios::beg);
size_t num;
sin >> num;
place = platform::CUDAPlace(static_cast<int>(num));
}
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place));
} }
}; };
...@@ -71,44 +107,73 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -71,44 +107,73 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
It launches another thread to execute the 'underlying reader' asynchronously, It launches another thread to execute the 'underlying reader' asynchronously,
which prevents reading process from blocking subsequent training. which prevents reading process from blocking subsequent training.
)DOC"); )DOC");
std::unordered_set<std::string> enum_range;
constexpr size_t kMaxCUDADevs = 128;
for (size_t i = 0; i < kMaxCUDADevs; ++i) {
enum_range.insert(string::Sprintf("CUDA:%d", i));
}
enum_range.insert("CPU");
AddAttr<std::string>("place", "The double buffer place, default is CPU")
.SetDefault("CPU")
.InEnum({enum_range});
} }
}; };
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
out->clear(); if (local_buffer_.payloads_.empty()) {
buffer_->Receive(out); buffer_->Receive(&local_buffer_);
}
*out = local_buffer_.payloads_;
local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) {
local_buffer_.ctx_->Wait();
}
} }
void DoubleBufferReader::ReInit() { void DoubleBufferReader::ReInit() {
reader_->ReInit(); reader_->ReInit();
buffer_->Close(); buffer_->Close();
// The existing prefetch thread will terminate for the buffer_ is closed. start_thread();
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
kDoubleBufferSize);
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
prefetch.detach();
} }
void DoubleBufferReader::PrefetchThreadFunc() { void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts."; VLOG(5) << "A new prefetch thread starts.";
while (true) { size_t gpu_ctx_offset = 0;
std::vector<framework::LoDTensor> batch; while (reader_->HasNext()) {
reader_->ReadNext(&batch); Item batch;
if (batch.empty()) { reader_->ReadNext(&batch.payloads_);
// EOF if (platform::is_gpu_place(place_)) {
buffer_->Close(); std::vector<framework::LoDTensor> gpu_batch;
VLOG(5) << "Reached the end of the file. The prefetch thread terminates."; auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
break; gpu_ctx_offset %= this->ctxs_.size();
gpu_batch.resize(batch.payloads_.size());
for (size_t i = 0; i < batch.payloads_.size(); ++i) {
framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
&gpu_batch[i]);
gpu_batch[i].set_lod(batch.payloads_[i].lod());
} }
batch.ctx_ = gpu_ctx.get();
std::swap(gpu_batch, batch.payloads_);
}
if (!buffer_->Send(&batch)) { if (!buffer_->Send(&batch)) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The " VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread terminates."; "prefetch thread terminates.";
break; break;
} }
} }
buffer_->Close();
} }
bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); } bool DoubleBufferReader::HasNext() const {
if (local_buffer_.payloads_.empty()) {
bool ok = buffer_->Receive(&local_buffer_);
return ok;
} else {
return true;
}
}
} // namespace reader } // namespace reader
} // namespace operators } // namespace operators
......
...@@ -19,11 +19,11 @@ namespace operators { ...@@ -19,11 +19,11 @@ namespace operators {
namespace reader { namespace reader {
template <typename T> template <typename T>
class RandomDataGenerator : public framework::FileReader { class RandomDataGenerator : public framework::ReaderBase {
public: public:
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min, RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min,
float max) float max)
: FileReader(shapes), min_(min), max_(max) { : framework::ReaderBase(), min_(min), max_(max), shapes_(shapes) {
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max); min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
unsigned int seed = std::random_device()(); unsigned int seed = std::random_device()();
...@@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader { ...@@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader {
float max_; float max_;
std::minstd_rand engine_; std::minstd_rand engine_;
std::uniform_real_distribution<float> dist_; std::uniform_real_distribution<float> dist_;
std::vector<framework::DDim> shapes_;
}; };
template <typename T> template <typename T>
......
...@@ -20,21 +20,22 @@ namespace operators { ...@@ -20,21 +20,22 @@ namespace operators {
namespace reader { namespace reader {
class RecordIOFileReader : public framework::FileReader { class RecordIOFileReader : public framework::FileReader {
public: public:
RecordIOFileReader(const std::string& filename, explicit RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& shapes) const std::vector<framework::DDim>& dims)
: FileReader(shapes), : FileReader(dims),
scanner_(filename), scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get( dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {} platform::CPUPlace())) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}
bool HasNext() const override { return scanner_.HasNext(); } bool HasNext() const override { return scanner_.HasNext(); }
void ReInit() override { scanner_.Reset(); } void ReInit() override { scanner_.Reset(); }
protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}
private: private:
recordio::Scanner scanner_; recordio::Scanner scanner_;
const platform::DeviceContext& dev_ctx_; const platform::DeviceContext& dev_ctx_;
...@@ -54,12 +55,12 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { ...@@ -54,12 +55,12 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
int(shape_concat.size()), int(shape_concat.size()),
"The accumulate of all ranks should be equal to the " "The accumulate of all ranks should be equal to the "
"shape concat's length."); "shape concat's length.");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
std::string filename = Attr<std::string>("filename"); std::string filename = Attr<std::string>("filename");
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader(filename, shapes)); out->Reset(
new RecordIOFileReader(filename, RestoreShapes(shape_concat, ranks)));
} }
}; };
...@@ -85,3 +86,5 @@ namespace reader = paddle::operators::reader; ...@@ -85,3 +86,5 @@ namespace reader = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader, REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader,
reader::CreateRecordIOReaderOp, reader::CreateRecordIOReaderOp,
reader::CreateRecordIOReaderOpMaker); reader::CreateRecordIOReaderOpMaker);
REGISTER_FILE_READER(recordio, reader::RecordIOFileReader);
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <random>
#include "glog/logging.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle { namespace paddle {
...@@ -20,43 +23,53 @@ namespace reader { ...@@ -20,43 +23,53 @@ namespace reader {
class ShuffleReader : public framework::DecoratedReader { class ShuffleReader : public framework::DecoratedReader {
public: public:
ShuffleReader(ReaderBase* reader, int buffer_size) ShuffleReader(ReaderBase* reader, size_t buffer_size, size_t seed = 0)
: DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) { : DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) {
buffer_.reserve(buffer_size); VLOG(10) << "Create shuffle reader of " << reader_;
if (seed_ == 0) {
std::random_device device;
seed_ = device();
}
ReadIntoBuffers();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer";
ReadIntoBuffers();
}
*out = buffer_[iteration_pos_++];
}
private: bool HasNext() const override {
int buffer_size_; return iteration_pos_ < buffer_.size() || reader_->HasNext();
std::vector<std::vector<framework::LoDTensor>> buffer_; }
size_t iteration_pos_;
};
void ShuffleReader::ReadNext(std::vector<framework::LoDTensor>* out) { private:
if (iteration_pos_ >= buffer_.size()) { void ReadIntoBuffers() {
// Reload buffer with new data
buffer_.clear(); buffer_.clear();
buffer_.reserve(buffer_size_); buffer_.reserve(buffer_size_);
for (int i = 0; i < buffer_size_; ++i) { iteration_pos_ = 0;
buffer_.push_back(std::vector<framework::LoDTensor>()); PADDLE_ENFORCE(reader_->HasNext());
reader_->ReadNext(&buffer_.back()); for (size_t i = 0; i < buffer_size_; ++i) {
if (buffer_.back().empty()) { if (!reader_->HasNext()) {
buffer_.pop_back();
break; break;
} }
buffer_.emplace_back();
reader_->ReadNext(&buffer_.back());
} }
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be std::mt19937 g(seed_);
// optimize. std::shuffle(buffer_.begin(), buffer_.end(), g);
std::random_shuffle(buffer_.begin(), buffer_.end()); seed_ = g(); // update seed_;
iteration_pos_ = 0; VLOG(10) << "random buffer size = " << buffer_.size();
}
out->clear();
if (!buffer_.empty()) {
std::swap(*out, buffer_[iteration_pos_++]);
} }
// if buffer_ is empty, the 'out' will return as an empty vector.
} size_t buffer_size_;
std::vector<std::vector<framework::LoDTensor>> buffer_;
size_t iteration_pos_;
size_t seed_;
};
class CreateShuffleReaderOp : public framework::OperatorBase { class CreateShuffleReaderOp : public framework::OperatorBase {
public: public:
...@@ -67,10 +80,10 @@ class CreateShuffleReaderOp : public framework::OperatorBase { ...@@ -67,10 +80,10 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out")) auto& var = detail::Ref(scope.FindVar(Output("Out")));
->template GetMutable<framework::ReaderHolder>(); var.GetMutable<framework::ReaderHolder>()->Reset(
out->Reset( new ShuffleReader(underlying_reader.Get(),
new ShuffleReader(underlying_reader.Get(), Attr<int>("buffer_size"))); static_cast<size_t>(Attr<int>("buffer_size"))));
} }
}; };
......
...@@ -31,6 +31,11 @@ std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat, ...@@ -31,6 +31,11 @@ std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat,
return res; return res;
} }
std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
static std::unordered_map<std::string, FileReaderCreator> regs;
return regs;
}
FileReaderMakerBase::FileReaderMakerBase( FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto, framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
......
...@@ -21,6 +21,20 @@ namespace paddle { ...@@ -21,6 +21,20 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
using FileReaderCreator = std::function<framework::ReaderBase*(
const std::string&, const std::vector<framework::DDim>&)>;
std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader>
int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = [](
const std::string& fn, const std::vector<paddle::framework::DDim>& dim) {
return new Reader(fn, dim);
};
return 0;
}
extern std::vector<framework::DDim> RestoreShapes( extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks); const std::vector<int>& shape_concat, const std::vector<int>& ranks);
...@@ -73,3 +87,15 @@ class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker { ...@@ -73,3 +87,15 @@ class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
paddle::operators::reader::DecoratedReaderInferShape, \ paddle::operators::reader::DecoratedReaderInferShape, \
paddle::framework::EmptyGradOpMaker, \ paddle::framework::EmptyGradOpMaker, \
paddle::operators::reader::DecoratedReaderInferVarType) paddle::operators::reader::DecoratedReaderInferVarType)
#define REGISTER_FILE_READER(_filetype, _reader) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
_reg_file_reader_##_filetype, \
"Must use REGISTER_FILE_READER in global namespace"); \
int TouchFileReader##_filetype() { return 0; } \
int _reg_file_reader_entry_##filetype = \
paddle::operators::reader::RegisterFileReader<_reader>(#_filetype)
#define USE_FILE_READER(filetype) \
extern int TouchFileReader##filetype(); \
static int _use_##filetype = TouchFileReader##filetype()
...@@ -21,7 +21,7 @@ from ..executor import global_scope ...@@ -21,7 +21,7 @@ from ..executor import global_scope
__all__ = [ __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'read_file' 'read_file', 'create_shuffle_reader', 'create_double_buffer_reader'
] ]
...@@ -245,6 +245,8 @@ def monkey_patch_reader_methods(reader): ...@@ -245,6 +245,8 @@ def monkey_patch_reader_methods(reader):
reader.eof = eof reader.eof = eof
reader.reset = reset reader.reset = reset
reader.stop_gradient = True
reader.persistable = True
return reader return reader
...@@ -285,6 +287,33 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): ...@@ -285,6 +287,33 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var) startup_var)
def __create_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_blk.append_op(
type=op_type,
inputs={'UnderlyingReader': reader},
outputs={'Out': [startup_var]},
attrs=attrs)
startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(),
startup_var)
def create_shuffle_reader(reader, buffer_size):
return __create_decorated_reader__('create_shuffle_reader', reader,
{'buffer_size': int(buffer_size)})
def create_double_buffer_reader(reader, place=None):
attrs = dict()
if place is not None:
attrs['place'] = str(place).upper()
return __create_decorated_reader__('create_double_buffer_reader', reader,
attrs)
def read_file(file_obj): def read_file(file_obj):
helper = LayerHelper('read_file') helper = LayerHelper('read_file')
out = [ out = [
......
...@@ -36,6 +36,7 @@ def convert_reader_to_recordio_file( ...@@ -36,6 +36,7 @@ def convert_reader_to_recordio_file(
feed_order=None): feed_order=None):
if feed_order is None: if feed_order is None:
feed_order = feeder.feed_names feed_order = feeder.feed_names
counter = 0
with create_recordio_writer(filename, compressor, with create_recordio_writer(filename, compressor,
max_num_records) as writer: max_num_records) as writer:
for batch in reader_creator(): for batch in reader_creator():
...@@ -43,3 +44,5 @@ def convert_reader_to_recordio_file( ...@@ -43,3 +44,5 @@ def convert_reader_to_recordio_file(
for each in feed_order: for each in feed_order:
writer.append_tensor(res[each]) writer.append_tensor(res[each])
writer.complete_append_tensor() writer.complete_append_tensor()
counter += 1
return counter
...@@ -13,9 +13,10 @@ ...@@ -13,9 +13,10 @@
# limitations under the License. # limitations under the License.
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.v2.dataset.mnist as mnist
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
class TestRecordIO(unittest.TestCase): class TestRecordIO(unittest.TestCase):
...@@ -31,10 +32,10 @@ class TestRecordIO(unittest.TestCase): ...@@ -31,10 +32,10 @@ class TestRecordIO(unittest.TestCase):
name='label', shape=[1], dtype='int64'), name='label', shape=[1], dtype='int64'),
], ],
place=fluid.CPUPlace()) place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file( self.num_batches = fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist.recordio', reader, feeder) './mnist.recordio', reader, feeder)
def test_main(self): def test_main(self, decorator_callback=None):
# use new program # use new program
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = fluid.layers.open_recordio_file( data_file = fluid.layers.open_recordio_file(
...@@ -42,6 +43,8 @@ class TestRecordIO(unittest.TestCase): ...@@ -42,6 +43,8 @@ class TestRecordIO(unittest.TestCase):
shapes=[[-1, 784], [-1, 1]], shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64']) dtypes=['float32', 'int64'])
if decorator_callback is not None:
data_file = decorator_callback(data_file)
img, label = fluid.layers.read_file(data_file) img, label = fluid.layers.read_file(data_file)
hidden = fluid.layers.fc(input=img, size=100, act='tanh') hidden = fluid.layers.fc(input=img, size=100, act='tanh')
...@@ -51,14 +54,28 @@ class TestRecordIO(unittest.TestCase): ...@@ -51,14 +54,28 @@ class TestRecordIO(unittest.TestCase):
fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_loss) fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_loss)
exe = fluid.Executor(fluid.CPUPlace()) if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
avg_loss_np = [] avg_loss_np = []
# train a pass # train a pass
batch_id = 0
while not data_file.eof(): while not data_file.eof():
tmp, = exe.run(fetch_list=[avg_loss]) tmp, = exe.run(fetch_list=[avg_loss])
avg_loss_np.append(tmp) avg_loss_np.append(tmp)
batch_id += 1
data_file.reset() data_file.reset()
self.assertEqual(batch_id, self.num_batches)
self.assertLess(avg_loss_np[-1], avg_loss_np[0]) self.assertLess(avg_loss_np[-1], avg_loss_np[0])
def test_shuffle_reader(self):
self.test_main(decorator_callback=lambda reader: fluid.layers.create_shuffle_reader(reader, buffer_size=200))
def test_double_buffer_reader(self):
self.test_main(decorator_callback=lambda reader: fluid.layers.create_double_buffer_reader(reader,
place='cuda:0' if fluid.core.is_compiled_with_cuda() else 'cpu'))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册