Commit 6dadb5de authored by sneaxiy

fix iterable=False reset bug, add some logs and polish code, test=develop

Parent 60d18a8f
......@@ -402,6 +402,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
} else if (pass->Type() == "set_reader_device_count_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
+ pass->Erase(kLocalScopes);
+ pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
+ &local_scopes);
}
VLOG(1) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph);
......
......@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
namespace paddle {
namespace framework {
......@@ -29,6 +30,8 @@ class SetReaderDeviceCountPass : public Pass {
int GetDeviceCount() const;
std::unordered_set<std::string> ReaderOpSet() const;
+ const Scope *GlobalScope() const;
};
int SetReaderDeviceCountPass::GetDeviceCount() const {
......@@ -40,9 +43,14 @@ std::unordered_set<std::string> SetReaderDeviceCountPass::ReaderOpSet() const {
return {"create_py_reader"};
}
+ const Scope *SetReaderDeviceCountPass::GlobalScope() const {
+ return Get<const std::vector<Scope *>>(details::kLocalScopes)[0];
+ }
void SetReaderDeviceCountPass::ApplyImpl(Graph *graph) const {
auto dev_cnt = GetDeviceCount();
auto reader_ops = ReaderOpSet();
+ auto scope = GlobalScope();
size_t found_op_num = 0;
for (auto &node : graph->Nodes()) {
......@@ -61,6 +69,18 @@ void SetReaderDeviceCountPass::ApplyImpl(Graph *graph) const {
op_base_attrs["device_index"] = dev_idx;
op_base_attrs["device_count"] = dev_cnt;
+ auto queue_name = op_handle.GetOp()->Input("blocking_queue");
+ auto var = scope->FindVar(queue_name);
+ PADDLE_ENFORCE_NOT_NULL(
+ var,
+ platform::errors::NotFound("Blocking queue of DataLoader not found"));
+ using QueueHolder =
+ operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder;
+ if (var->IsType<QueueHolder>()) {
+ var->GetMutable<QueueHolder>()->GetQueue()->SetDeviceCount(dev_cnt);
+ }
++found_op_num;
VLOG(10) << "Found op " << op_desc->Type() << " on device " << dev_idx;
}
......
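To make the intent of the new pass code above concrete, here is a minimal, self-contained sketch of the same pattern: for each reader op the pass finds, look up that op's blocking queue by name in the first local scope and propagate the device count to it. The types and names here (SimpleQueue, ApplyDeviceCount, the map-based Scope) are illustrative stand-ins, not Paddle APIs.

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Tiny stand-in for the blocking queue: it only records the device count.
struct SimpleQueue {
  void SetDeviceCount(size_t n) { device_count = n; }
  size_t device_count = 0;
};

// The "global scope" is modeled as a name -> queue map.
using Scope = std::unordered_map<std::string, std::shared_ptr<SimpleQueue>>;

// For every reader op found in the graph, look up its blocking queue by name
// and propagate the device count, analogous to the FindVar/IsType code above.
void ApplyDeviceCount(const std::vector<std::string>& reader_queue_names,
                      Scope* global_scope, size_t dev_cnt) {
  for (const auto& name : reader_queue_names) {
    auto it = global_scope->find(name);
    assert(it != global_scope->end() && "Blocking queue of DataLoader not found");
    it->second->SetDeviceCount(dev_cnt);
  }
}

int main() {
  Scope scope{{"loader_queue_0", std::make_shared<SimpleQueue>()}};
  ApplyDeviceCount({"loader_queue_0"}, &scope, 2);
  assert(scope["loader_queue_0"]->device_count == 2);
  return 0;
}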
......@@ -117,6 +117,10 @@ class DecoratedReader : public ReaderBase,
~DecoratedReader();
+ const std::shared_ptr<ReaderBase>& UnderlyingReader() const {
+ return reader_;
+ }
protected:
void ShutdownImpl() override {
VLOG(1) << "ShutdownImpl";
......@@ -190,6 +194,8 @@ class ReaderHolder {
return reader_->NeedCheckFeed();
}
+ void Clear() { reader_.reset(); }
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
private:
......
......@@ -27,12 +27,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
const platform::Place& dev_place) const override {
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
- if (out->Get() != nullptr) {
- return;
- }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
+ if (out->Get() != nullptr) {
+ auto* decorated_reader =
+ dynamic_cast<framework::DecoratedReader*>(out->Get().get());
+ PADDLE_ENFORCE_NOT_NULL(
+ decorated_reader,
+ platform::errors::NotFound("Not inited with DecoratedReader"));
+ if (decorated_reader->UnderlyingReader() == underlying_reader.Get()) {
+ return;
+ }
+ }
auto place_str = Attr<std::string>("place");
platform::Place place;
if (place_str == "AUTO") {
......@@ -47,6 +55,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place = platform::CUDAPlace(static_cast<int>(num));
}
VLOG(10) << "Create new double buffer reader on " << place;
out->Reset(framework::MakeDecoratedReader<BufferedReader>(underlying_reader,
place, 2));
}
......
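The hunk above changes create_double_buffer_reader so that re-running the startup program no longer returns early whenever an output reader already exists; it only reuses the old reader when the underlying reader is unchanged, so a reader rebuilt after a reset gets a fresh double-buffer wrapper. A rough, self-contained sketch of that check, using simplified stand-in types (ReaderBase, DecoratedReader, CreateDoubleBuffer) rather than Paddle's real classes:

#include <cassert>
#include <memory>

struct ReaderBase {
  virtual ~ReaderBase() = default;
};

struct DecoratedReader : ReaderBase {
  explicit DecoratedReader(std::shared_ptr<ReaderBase> underlying)
      : underlying_(std::move(underlying)) {}
  const std::shared_ptr<ReaderBase>& UnderlyingReader() const { return underlying_; }
 private:
  std::shared_ptr<ReaderBase> underlying_;
};

// Called every time the startup program runs; rebuilds only when needed.
void CreateDoubleBuffer(std::shared_ptr<ReaderBase>* out,
                        const std::shared_ptr<ReaderBase>& underlying) {
  if (*out != nullptr) {
    auto* decorated = dynamic_cast<DecoratedReader*>(out->get());
    assert(decorated != nullptr);                 // must have been created here before
    if (decorated->UnderlyingReader() == underlying) {
      return;                                     // same source: keep the old reader
    }
  }
  *out = std::make_shared<DecoratedReader>(underlying);  // (re)create the wrapper
}

int main() {
  auto src = std::make_shared<ReaderBase>();
  std::shared_ptr<ReaderBase> out;
  CreateDoubleBuffer(&out, src);
  auto first = out;
  CreateDoubleBuffer(&out, src);   // underlying unchanged: the old reader is reused
  assert(out == first);
  auto src2 = std::make_shared<ReaderBase>();
  CreateDoubleBuffer(&out, src2);  // underlying replaced (e.g. after a reset): rebuild
  assert(out != first);
  return 0;
}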
......@@ -40,6 +40,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
queue_name);
std::shared_ptr<LoDTensorBlockingQueue> queue;
std::shared_ptr<OrderedMultiDeviceLoDTensorBlockingQueue> ordered_queue;
+ int dev_idx = -1;
if (queue_holder_var->IsType<LoDTensorBlockingQueueHolder>()) {
queue = queue_holder_var->Get<LoDTensorBlockingQueueHolder>().GetQueue();
} else if (queue_holder_var
......@@ -47,10 +48,9 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto* queue_holder =
queue_holder_var
->GetMutable<OrderedMultiDeviceLoDTensorBlockingQueueHolder>();
auto dev_cnt = Attr<int>("device_count");
auto dev_idx = static_cast<size_t>(Attr<int>("device_index"));
dev_idx = Attr<int>("device_index");
ordered_queue = queue_holder->GetQueue();
- ordered_queue->InitOnce(dev_cnt);
+ ordered_queue->SetDeviceCount(Attr<int>("device_count"));
queue = ordered_queue->GetQueue(dev_idx);
}
......@@ -87,15 +87,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto py_reader =
std::make_shared<PyReader>(queue, dims, var_types, need_check_feed);
if (ordered_queue) {
- ordered_queue->AddResetMethod([py_reader] {
- auto end_readers = py_reader->GetEndPoints();
- for (auto* reader : end_readers) {
- reader->Shutdown();
- }
- for (auto* reader : end_readers) {
- reader->Start();
- }
- });
+ ordered_queue->SetResetMethod(dev_idx, [out] { out->Clear(); });
}
out->Reset(py_reader);
}
......@@ -109,8 +101,9 @@ class CreatePyReaderOpMaker : public FileReaderMakerBase {
AddAttr<int>("device_index", "The device index this reader offers data")
.SetDefault(0);
AddAttr<int>("device_count",
"The total number of devices the reader offers data")
"The total device number this reader offers data")
.SetDefault(1);
AddComment(R"DOC(
......
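The reset wiring above replaces the old per-reader Shutdown/Start loop with one callback per device index that simply clears that device's output ReaderHolder; on the next run, create_py_reader rebuilds the reader from the re-created queue. A minimal sketch of that registry pattern under those assumptions; ResetRegistry and this ReaderHolder::Clear only mimic the real classes:

#include <cassert>
#include <functional>
#include <memory>
#include <vector>

struct ReaderHolder {
  std::shared_ptr<int> reader;               // stand-in for the real reader object
  void Clear() { reader.reset(); }
};

struct ResetRegistry {
  void SetResetMethod(size_t idx, std::function<void()> fn) {
    if (methods.size() <= idx) methods.resize(idx + 1);
    methods[idx] = std::move(fn);
  }
  void Reset() {
    for (auto& fn : methods)
      if (fn) fn();                          // skip slots that were never registered
  }
  std::vector<std::function<void()>> methods;
};

int main() {
  ReaderHolder out0, out1;
  out0.reader = std::make_shared<int>(0);
  out1.reader = std::make_shared<int>(1);

  ResetRegistry registry;
  registry.SetResetMethod(0, [&out0] { out0.Clear(); });
  registry.SetResetMethod(1, [&out1] { out1.Clear(); });

  registry.Reset();                          // both holders are cleared ...
  assert(out0.reader == nullptr && out1.reader == nullptr);
  // ... and the next create_py_reader / create_double_buffer run re-creates them.
  return 0;
}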
......@@ -32,6 +32,8 @@ class LoDTensorBlockingQueue {
explicit LoDTensorBlockingQueue(size_t capacity, bool speed_test_mode = false)
: queue_(capacity, speed_test_mode) {}
+ ~LoDTensorBlockingQueue() { VLOG(10) << "Destruct LoDTensorBlockingQueue"; }
bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
return queue_.Send(lod_tensor_vec);
}
......@@ -62,7 +64,7 @@ class LoDTensorBlockingQueue {
inline void Kill() { queue_.Kill(); }
- inline bool WaitForInited() { return true; }
+ inline bool WaitForInited(size_t) { return true; }
private:
BlockingQueue<std::vector<framework::LoDTensor>> queue_;
......@@ -74,47 +76,47 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
bool speed_test_mode = false)
: capacity_(capacity), speed_test_mode_(speed_test_mode) {}
- inline bool WaitForInited() {
+ ~OrderedMultiDeviceLoDTensorBlockingQueue() {
+ VLOG(10) << "Destruct OrderedMultiDeviceLoDTensorBlockingQueue";
+ }
+ bool WaitForInited(size_t milliseconds) {
std::unique_lock<std::mutex> lock(init_mutex_);
- cv_.wait(lock, [this] { return queues_ != nullptr || is_closing_; });
- is_closing_ = false;
- return queues_ != nullptr;
+ return cv_.wait_for(lock, std::chrono::milliseconds(milliseconds),
+ [this] { return !queues_.empty(); });
}
- inline void InitOnce(size_t dev_cnt) {
- PADDLE_ENFORCE_GE(dev_cnt, 1, platform::errors::InvalidArgument(
- "Device count to init "
- "OrderedMultiDeviceLoDTensorBlockingQueue"
- " must be larger than 1"));
- VLOG(3) << "Ordered queue init start";
+ void SetDeviceCount(size_t dev_cnt) {
{
std::lock_guard<std::mutex> lock(init_mutex_);
- if (queues_) {
- PADDLE_ENFORCE_EQ(queues_->size(), dev_cnt,
+ PADDLE_ENFORCE_GE(dev_cnt, 1,
+ platform::errors::InvalidArgument(
+ "Device count to init "
+ "OrderedMultiDeviceLoDTensorBlockingQueue"
+ " must be larger than 1"));
+ if (!queues_.empty()) {
+ PADDLE_ENFORCE_EQ(queues_.size(), dev_cnt,
platform::errors::InvalidArgument(
"Device count to init queue must be equal"));
} else {
queues_.reset(
new std::vector<std::shared_ptr<LoDTensorBlockingQueue>>(dev_cnt));
for (auto& item : *queues_) {
auto cap = (capacity_ + dev_cnt - 1) / dev_cnt;
item.reset(new LoDTensorBlockingQueue(cap, speed_test_mode_));
}
"queues should be only inited once"));
return;
}
VLOG(1) << "Init queue with size " << dev_cnt;
queues_.resize(dev_cnt);
for (auto& item : queues_) {
auto cap = (capacity_ + dev_cnt - 1) / dev_cnt;
item.reset(new LoDTensorBlockingQueue(cap, speed_test_mode_));
}
}
VLOG(3) << "Ordered queue init finish";
cv_.notify_all();
}
const std::shared_ptr<LoDTensorBlockingQueue>& GetQueue(size_t idx) const {
std::lock_guard<std::mutex> lock(init_mutex_);
- PADDLE_ENFORCE_NOT_NULL(queues_,
- platform::errors::NotFound(
- "Queues must be inited first before getting"));
+ EnforceIsInited();
PADDLE_ENFORCE_LT(
- idx, queues_->size(),
+ idx, queues_.size(),
platform::errors::OutOfRange("The queue index is out of range"));
- return (*queues_)[idx];
+ return queues_[idx];
}
bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
......@@ -123,65 +125,74 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
inline size_t Size() const {
size_t size = 0;
- if (queues_) {
- for (auto& item : *queues_) {
- size += item->Size();
- }
+ for (auto& item : queues_) {
+ size += item->Size();
+ }
return size;
}
inline void Close() {
- {
- std::lock_guard<std::mutex> lock(init_mutex_);
- if (queues_ == nullptr) {
- is_closing_ = true;
- }
- }
- cv_.notify_all();
- if (queues_) {
- for (auto& item : *queues_) {
- item->Close();
- }
+ for (auto& item : queues_) {
+ item->Close();
+ }
+ data_index_ = 0;
}
inline void Kill() {
- if (queues_) {
- for (auto& item : *queues_) {
- item->Kill();
- }
+ for (auto& item : queues_) {
+ item->Kill();
+ }
}
inline void Reset() {
- std::lock_guard<std::mutex> reset_lock(reset_mutex_);
- for (auto& method : reset_methods_) {
- method();
+ {
+ std::lock_guard<std::mutex> reset_lock(reset_mutex_);
+ for (auto& method : reset_methods_) {
+ if (method) method();
+ }
+ }
+ auto dev_cnt = queues_.size();
+ for (auto& item : queues_) {
+ auto cap = (capacity_ + dev_cnt - 1) / dev_cnt;
+ item.reset(new LoDTensorBlockingQueue(cap, speed_test_mode_));
+ }
data_index_ = 0;
}
- inline void AddResetMethod(const std::function<void()>& reset_method) {
+ inline void SetResetMethod(size_t idx,
+ const std::function<void()>& reset_method) {
std::lock_guard<std::mutex> reset_lock(reset_mutex_);
- reset_methods_.emplace_back(reset_method);
+ EnforceIsInited();
+ if (reset_methods_.size() <= idx) {
+ reset_methods_.resize(idx + 1);
+ }
+ reset_methods_[idx] = reset_method;
}
private:
const std::shared_ptr<LoDTensorBlockingQueue>& CurQueue() {
- return (*queues_)[data_index_.fetch_add(1) % queues_->size()];
+ EnforceIsInited();
+ return queues_[data_index_.fetch_add(1) % queues_.size()];
}
private:
+ void EnforceIsInited() const {
+ PADDLE_ENFORCE_EQ(queues_.empty(), false,
+ platform::errors::NotFound("queue has not been inited"));
+ }
private:
- std::unique_ptr<std::vector<std::shared_ptr<LoDTensorBlockingQueue>>> queues_;
+ std::vector<std::shared_ptr<LoDTensorBlockingQueue>> queues_;
mutable std::atomic<uint64_t> data_index_{0};
size_t dev_cnt_{0};
const size_t capacity_;
const bool speed_test_mode_;
bool is_closed_{false};
std::vector<std::function<void()>> reset_methods_;
mutable std::mutex reset_mutex_;
- bool is_closing_{false};
mutable std::mutex init_mutex_;
mutable std::condition_variable cv_;
};
......
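Taken together, the queue changes above replace InitOnce with an idempotent SetDeviceCount, turn WaitForInited into a timed wait, and make Reset re-create the per-device sub-queues while splitting the total capacity evenly (rounded up). Below is a compact, self-contained model of that lifecycle under those assumptions; MultiDeviceQueue and its members are illustrative stand-ins, not the real OrderedMultiDeviceLoDTensorBlockingQueue:

#include <cassert>
#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <vector>

class MultiDeviceQueue {
 public:
  explicit MultiDeviceQueue(size_t capacity) : capacity_(capacity) {}

  void SetDeviceCount(size_t dev_cnt) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (!queues_.empty()) {
        assert(queues_.size() == dev_cnt);   // may only be initialized once
        return;
      }
      Rebuild(dev_cnt);
    }
    cv_.notify_all();
  }

  // Timed wait: callers can wake up periodically and re-check an exit flag.
  bool WaitForInited(size_t milliseconds) {
    std::unique_lock<std::mutex> lock(mutex_);
    return cv_.wait_for(lock, std::chrono::milliseconds(milliseconds),
                        [this] { return !queues_.empty(); });
  }

  size_t GetQueueCapacity(size_t idx) const { return queues_.at(idx)->capacity; }

  void Reset() {                             // drop and re-create every sub-queue
    std::lock_guard<std::mutex> lock(mutex_);
    Rebuild(queues_.size());
  }

 private:
  struct SubQueue { size_t capacity; };

  void Rebuild(size_t dev_cnt) {
    queues_.clear();
    for (size_t i = 0; i < dev_cnt; ++i) {
      // split the total capacity evenly across devices, rounding up
      queues_.push_back(std::make_shared<SubQueue>(
          SubQueue{(capacity_ + dev_cnt - 1) / dev_cnt}));
    }
  }

  std::vector<std::shared_ptr<SubQueue>> queues_;
  const size_t capacity_;
  std::mutex mutex_;
  std::condition_variable cv_;
};

int main() {
  MultiDeviceQueue q(/*capacity=*/10);
  assert(!q.WaitForInited(1));         // not initialized yet: times out instead of hanging
  q.SetDeviceCount(4);
  assert(q.WaitForInited(1));
  assert(q.GetQueueCapacity(0) == 3);  // ceil(10 / 4)
  q.Reset();                           // sub-queues are rebuilt, device count unchanged
  return 0;
}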
......@@ -354,7 +354,7 @@ void BindReader(py::module *module) {
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places,
bool use_double_buffer) {
- queue->InitOnce(dst_places.size());
+ queue->SetDeviceCount(dst_places.size());
return new MultiDeviceFeedReader<
reader::OrderedMultiDeviceLoDTensorBlockingQueue>(
queue, names, shapes, dtypes, need_check_feed, dst_places,
......
......@@ -347,7 +347,6 @@ class DygraphGeneratorLoader(DataLoaderBase):
self._batch_reader = None
self._places = None
self._feed_list = feed_list
- self._keep_order = True
if not capacity:
raise ValueError("Please give value to capacity.")
......@@ -420,7 +419,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
self._dtypes = []
self._need_check_feed = []
self._blocking_queue = core.init_lod_tensor_blocking_queue(
- core.Variable(), self._capacity, self._keep_order)
+ core.Variable(), self._capacity, False)
self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer)
......@@ -635,6 +634,7 @@ class GeneratorLoader(DataLoaderBase):
self._thread = None
self._queue = None
self._feed_list = feed_list
+ self._exited = False
if not capacity:
raise ValueError("Please give value to capacity.")
self._iterable = iterable
......@@ -798,8 +798,9 @@ class GeneratorLoader(DataLoaderBase):
def _start(self):
def __thread_main__():
try:
- if not self._queue.wait_for_inited():
- return
+ while not self._queue.wait_for_inited(1):
+ if self._exited:
+ return
for tensors in self._tensor_reader():
array = core.LoDTensorArray()
......@@ -829,10 +830,12 @@ class GeneratorLoader(DataLoaderBase):
def _reset(self):
self._queue.close()
+ self._exited = True
thread = self._thread
if thread is not None:
thread.join()
+ self._exited = False
self._reader.reset()
def set_sample_generator(self,
......
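The Python change above addresses the iterable=False hang: the feeding thread now polls wait_for_inited(1) in a loop and re-checks an exit flag that _reset() sets before joining the thread, instead of blocking forever on a queue that may never be initialized. A hedged, generic C++ sketch of that shutdown pattern (plain std threading, not the DataLoader code itself):

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <thread>

std::mutex mu;
std::condition_variable cv;
bool inited = false;
std::atomic<bool> exited{false};

bool WaitForInited(int milliseconds) {
  std::unique_lock<std::mutex> lock(mu);
  return cv.wait_for(lock, std::chrono::milliseconds(milliseconds),
                     [] { return inited; });
}

void WorkerMain() {
  while (!WaitForInited(1)) {     // poll with a short timeout instead of waiting forever
    if (exited.load()) return;    // a reset was requested before init: bail out
  }
  // ... feed data here once the queue has been initialized ...
}

int main() {
  std::thread worker(WorkerMain);
  // Reset path: signal the worker, join it, then clear the flag for the next epoch.
  exited.store(true);
  worker.join();
  exited.store(false);
  return 0;
}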