未验证 提交 48f213e5 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #8991 from reyoung/feature/shuffle_reader

Feature/shuffle reader
......@@ -445,15 +445,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
if (var->IsType<ReaderHolder>()) {
return var->Get<ReaderHolder>().shapes();
} else {
PADDLE_THROW(
"Only ReaderHolder support 'GetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
PADDLE_THROW("Only compile time support this method");
}
void SetDim(const std::string& name, const DDim& dim) override {
......@@ -470,15 +462,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
Variable* var = scope_.FindVar(name);
if (var->IsType<ReaderHolder>()) {
var->GetMutable<ReaderHolder>()->set_shapes(dims);
} else {
PADDLE_THROW(
"Only ReaderHolder support 'SetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
PADDLE_THROW("Only compile time support this method");
}
proto::VarType::Type GetVarType(const std::string& name) const override {
......
......@@ -16,14 +16,22 @@
namespace paddle {
namespace framework {
ReaderBase::~ReaderBase() {}
DDim ReaderBase::shape(size_t idx) const {
PADDLE_ENFORCE_LT(
idx, shapes_.size(),
"Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
shapes_.size());
return shapes_[idx];
}
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
void FileReader::ReadNext(std::vector<LoDTensor> *out) {
ReadNextImpl(out);
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
for (size_t i = 0; i < dims_.size(); ++i) {
auto &actual = out->at(i).dims();
auto &expect = dims_[i];
PADDLE_ENFORCE_EQ(actual.size(), expect.size());
for (int j = 0; j < actual.size(); ++j) {
PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
}
}
}
} // namespace framework
} // namespace paddle
......@@ -16,40 +16,29 @@
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
#include <memory>
#include <thread>
#include <vector>
namespace paddle {
namespace framework {
class ReaderBase {
public:
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
PADDLE_ENFORCE(!shapes_.empty());
}
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
virtual void ReInit() = 0;
DDim shape(size_t idx) const;
std::vector<DDim> shapes() const { return shapes_; }
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
virtual bool HasNext() const = 0;
virtual ~ReaderBase() {}
protected:
std::vector<DDim> shapes_;
};
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {}
virtual ~ReaderBase();
};
class DecoratedReader : public ReaderBase {
public:
explicit DecoratedReader(ReaderBase* reader)
: ReaderBase(reader->shapes()), reader_(reader) {
explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
......@@ -61,6 +50,19 @@ class DecoratedReader : public ReaderBase {
ReaderBase* reader_;
};
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& dims);
void ReadNext(std::vector<LoDTensor>* out) override;
protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
private:
std::vector<DDim> dims_;
};
// The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables.
class ReaderHolder {
......@@ -78,19 +80,6 @@ class ReaderHolder {
reader_->ReInit();
}
DDim shape(size_t idx) const {
PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shape(idx);
}
std::vector<DDim> shapes() const {
PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shapes();
}
void set_shapes(const std::vector<DDim>& shapes) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->set_shapes(shapes);
}
bool HasNext() const { return reader_->HasNext(); }
private:
......
......@@ -24,11 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2;
class DoubleBufferReader : public framework::DecoratedReader {
public:
explicit DoubleBufferReader(ReaderBase* reader)
: DecoratedReader(reader),
buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>(
kDoubleBufferSize)) {
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
struct Item {
Item() : ctx_(nullptr) {}
std::vector<framework::LoDTensor> payloads_;
platform::DeviceContext* ctx_;
};
explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) {
for (size_t i = 0; i < kDoubleBufferSize; ++i) {
if (platform::is_gpu_place(place_)) {
#ifdef PADDLE_WITH_CUDA
ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_)));
#endif
}
}
start_thread();
}
void start_thread() {
buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
std::thread prefetch([this] { PrefetchThreadFunc(); });
prefetch.detach();
}
......@@ -42,7 +62,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
private:
void PrefetchThreadFunc();
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
framework::Channel<Item>* buffer_;
platform::Place place_;
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
mutable Item local_buffer_;
};
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
......@@ -56,7 +79,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new DoubleBufferReader(underlying_reader.Get()));
auto place_str = Attr<std::string>("place");
platform::Place place;
if (place_str == "CPU") {
place = platform::CPUPlace();
} else {
std::istringstream sin(place_str);
sin.seekg(std::string("CUDA:").size(), std::ios::beg);
size_t num;
sin >> num;
place = platform::CUDAPlace(static_cast<int>(num));
}
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place));
}
};
......@@ -71,44 +107,73 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
It launches another thread to execute the 'underlying reader' asynchronously,
which prevents reading process from blocking subsequent training.
)DOC");
std::unordered_set<std::string> enum_range;
constexpr size_t kMaxCUDADevs = 128;
for (size_t i = 0; i < kMaxCUDADevs; ++i) {
enum_range.insert(string::Sprintf("CUDA:%d", i));
}
enum_range.insert("CPU");
AddAttr<std::string>("place", "The double buffer place, default is CPU")
.SetDefault("CPU")
.InEnum({enum_range});
}
};
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
out->clear();
buffer_->Receive(out);
if (local_buffer_.payloads_.empty()) {
buffer_->Receive(&local_buffer_);
}
*out = local_buffer_.payloads_;
local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) {
local_buffer_.ctx_->Wait();
}
}
void DoubleBufferReader::ReInit() {
reader_->ReInit();
buffer_->Close();
// The existing prefetch thread will terminate for the buffer_ is closed.
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
kDoubleBufferSize);
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
prefetch.detach();
start_thread();
}
void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
while (true) {
std::vector<framework::LoDTensor> batch;
reader_->ReadNext(&batch);
if (batch.empty()) {
// EOF
buffer_->Close();
VLOG(5) << "Reached the end of the file. The prefetch thread terminates.";
break;
size_t gpu_ctx_offset = 0;
while (reader_->HasNext()) {
Item batch;
reader_->ReadNext(&batch.payloads_);
if (platform::is_gpu_place(place_)) {
std::vector<framework::LoDTensor> gpu_batch;
auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
gpu_ctx_offset %= this->ctxs_.size();
gpu_batch.resize(batch.payloads_.size());
for (size_t i = 0; i < batch.payloads_.size(); ++i) {
framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
&gpu_batch[i]);
gpu_batch[i].set_lod(batch.payloads_[i].lod());
}
batch.ctx_ = gpu_ctx.get();
std::swap(gpu_batch, batch.payloads_);
}
if (!buffer_->Send(&batch)) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread terminates.";
break;
}
}
buffer_->Close();
}
bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); }
bool DoubleBufferReader::HasNext() const {
if (local_buffer_.payloads_.empty()) {
bool ok = buffer_->Receive(&local_buffer_);
return ok;
} else {
return true;
}
}
} // namespace reader
} // namespace operators
......
......@@ -19,11 +19,11 @@ namespace operators {
namespace reader {
template <typename T>
class RandomDataGenerator : public framework::FileReader {
class RandomDataGenerator : public framework::ReaderBase {
public:
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min,
float max)
: FileReader(shapes), min_(min), max_(max) {
: framework::ReaderBase(), min_(min), max_(max), shapes_(shapes) {
PADDLE_ENFORCE_LE(
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
unsigned int seed = std::random_device()();
......@@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader {
float max_;
std::minstd_rand engine_;
std::uniform_real_distribution<float> dist_;
std::vector<framework::DDim> shapes_;
};
template <typename T>
......
......@@ -20,21 +20,22 @@ namespace operators {
namespace reader {
class RecordIOFileReader : public framework::FileReader {
public:
RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& shapes)
: FileReader(shapes),
explicit RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& dims)
: FileReader(dims),
scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}
bool HasNext() const override { return scanner_.HasNext(); }
void ReInit() override { scanner_.Reset(); }
protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}
private:
recordio::Scanner scanner_;
const platform::DeviceContext& dev_ctx_;
......@@ -54,12 +55,12 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
std::string filename = Attr<std::string>("filename");
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader(filename, shapes));
out->Reset(
new RecordIOFileReader(filename, RestoreShapes(shape_concat, ranks)));
}
};
......@@ -85,3 +86,5 @@ namespace reader = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader,
reader::CreateRecordIOReaderOp,
reader::CreateRecordIOReaderOpMaker);
REGISTER_FILE_READER(recordio, reader::RecordIOFileReader);
......@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <random>
#include "glog/logging.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
......@@ -20,43 +23,53 @@ namespace reader {
class ShuffleReader : public framework::DecoratedReader {
public:
ShuffleReader(ReaderBase* reader, int buffer_size)
: DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) {
buffer_.reserve(buffer_size);
ShuffleReader(ReaderBase* reader, size_t buffer_size, size_t seed = 0)
: DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) {
VLOG(10) << "Create shuffle reader of " << reader_;
if (seed_ == 0) {
std::random_device device;
seed_ = device();
}
ReadIntoBuffers();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer";
ReadIntoBuffers();
}
*out = buffer_[iteration_pos_++];
}
private:
int buffer_size_;
std::vector<std::vector<framework::LoDTensor>> buffer_;
size_t iteration_pos_;
};
bool HasNext() const override {
return iteration_pos_ < buffer_.size() || reader_->HasNext();
}
void ShuffleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (iteration_pos_ >= buffer_.size()) {
// Reload buffer with new data
private:
void ReadIntoBuffers() {
buffer_.clear();
buffer_.reserve(buffer_size_);
for (int i = 0; i < buffer_size_; ++i) {
buffer_.push_back(std::vector<framework::LoDTensor>());
reader_->ReadNext(&buffer_.back());
if (buffer_.back().empty()) {
buffer_.pop_back();
iteration_pos_ = 0;
PADDLE_ENFORCE(reader_->HasNext());
for (size_t i = 0; i < buffer_size_; ++i) {
if (!reader_->HasNext()) {
break;
}
buffer_.emplace_back();
reader_->ReadNext(&buffer_.back());
}
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
// optimize.
std::random_shuffle(buffer_.begin(), buffer_.end());
iteration_pos_ = 0;
std::mt19937 g(seed_);
std::shuffle(buffer_.begin(), buffer_.end(), g);
seed_ = g(); // update seed_;
VLOG(10) << "random buffer size = " << buffer_.size();
}
out->clear();
if (!buffer_.empty()) {
std::swap(*out, buffer_[iteration_pos_++]);
}
// if buffer_ is empty, the 'out' will return as an empty vector.
}
size_t buffer_size_;
std::vector<std::vector<framework::LoDTensor>> buffer_;
size_t iteration_pos_;
size_t seed_;
};
class CreateShuffleReaderOp : public framework::OperatorBase {
public:
......@@ -67,10 +80,10 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(
new ShuffleReader(underlying_reader.Get(), Attr<int>("buffer_size")));
auto& var = detail::Ref(scope.FindVar(Output("Out")));
var.GetMutable<framework::ReaderHolder>()->Reset(
new ShuffleReader(underlying_reader.Get(),
static_cast<size_t>(Attr<int>("buffer_size"))));
}
};
......
......@@ -31,6 +31,11 @@ std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat,
return res;
}
std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
static std::unordered_map<std::string, FileReaderCreator> regs;
return regs;
}
FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
......
......@@ -21,6 +21,20 @@ namespace paddle {
namespace operators {
namespace reader {
using FileReaderCreator = std::function<framework::ReaderBase*(
const std::string&, const std::vector<framework::DDim>&)>;
std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader>
int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = [](
const std::string& fn, const std::vector<paddle::framework::DDim>& dim) {
return new Reader(fn, dim);
};
return 0;
}
extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks);
......@@ -73,3 +87,15 @@ class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
paddle::operators::reader::DecoratedReaderInferShape, \
paddle::framework::EmptyGradOpMaker, \
paddle::operators::reader::DecoratedReaderInferVarType)
#define REGISTER_FILE_READER(_filetype, _reader) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
_reg_file_reader_##_filetype, \
"Must use REGISTER_FILE_READER in global namespace"); \
int TouchFileReader##_filetype() { return 0; } \
int _reg_file_reader_entry_##filetype = \
paddle::operators::reader::RegisterFileReader<_reader>(#_filetype)
#define USE_FILE_READER(filetype) \
extern int TouchFileReader##filetype(); \
static int _use_##filetype = TouchFileReader##filetype()
......@@ -21,7 +21,7 @@ from ..executor import global_scope
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'read_file'
'read_file', 'create_shuffle_reader', 'create_double_buffer_reader'
]
......@@ -245,6 +245,8 @@ def monkey_patch_reader_methods(reader):
reader.eof = eof
reader.reset = reset
reader.stop_gradient = True
reader.persistable = True
return reader
......@@ -285,6 +287,33 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var)
def __create_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_blk.append_op(
type=op_type,
inputs={'UnderlyingReader': reader},
outputs={'Out': [startup_var]},
attrs=attrs)
startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(),
startup_var)
def create_shuffle_reader(reader, buffer_size):
return __create_decorated_reader__('create_shuffle_reader', reader,
{'buffer_size': int(buffer_size)})
def create_double_buffer_reader(reader, place=None):
attrs = dict()
if place is not None:
attrs['place'] = str(place).upper()
return __create_decorated_reader__('create_double_buffer_reader', reader,
attrs)
def read_file(file_obj):
helper = LayerHelper('read_file')
out = [
......
......@@ -36,6 +36,7 @@ def convert_reader_to_recordio_file(
feed_order=None):
if feed_order is None:
feed_order = feeder.feed_names
counter = 0
with create_recordio_writer(filename, compressor,
max_num_records) as writer:
for batch in reader_creator():
......@@ -43,3 +44,5 @@ def convert_reader_to_recordio_file(
for each in feed_order:
writer.append_tensor(res[each])
writer.complete_append_tensor()
counter += 1
return counter
......@@ -13,9 +13,10 @@
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.v2.dataset.mnist as mnist
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
class TestRecordIO(unittest.TestCase):
......@@ -31,10 +32,10 @@ class TestRecordIO(unittest.TestCase):
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
self.num_batches = fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist.recordio', reader, feeder)
def test_main(self):
def test_main(self, decorator_callback=None):
# use new program
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = fluid.layers.open_recordio_file(
......@@ -42,6 +43,8 @@ class TestRecordIO(unittest.TestCase):
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
if decorator_callback is not None:
data_file = decorator_callback(data_file)
img, label = fluid.layers.read_file(data_file)
hidden = fluid.layers.fc(input=img, size=100, act='tanh')
......@@ -51,14 +54,28 @@ class TestRecordIO(unittest.TestCase):
fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_loss)
exe = fluid.Executor(fluid.CPUPlace())
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
avg_loss_np = []
# train a pass
batch_id = 0
while not data_file.eof():
tmp, = exe.run(fetch_list=[avg_loss])
avg_loss_np.append(tmp)
batch_id += 1
data_file.reset()
self.assertEqual(batch_id, self.num_batches)
self.assertLess(avg_loss_np[-1], avg_loss_np[0])
def test_shuffle_reader(self):
self.test_main(decorator_callback=lambda reader: fluid.layers.create_shuffle_reader(reader, buffer_size=200))
def test_double_buffer_reader(self):
self.test_main(decorator_callback=lambda reader: fluid.layers.create_double_buffer_reader(reader,
place='cuda:0' if fluid.core.is_compiled_with_cuda() else 'cpu'))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册