提交 d4d946db 编写于 作者: S sneaxiy

update blocking queue

上级 67556e4a
...@@ -28,7 +28,7 @@ class PyReader : public framework::ReaderBase { ...@@ -28,7 +28,7 @@ class PyReader : public framework::ReaderBase {
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
bool success; bool success;
*out = queue_->Dequeue(&success); *out = queue_->Pop(&success);
if (!success) out->clear(); if (!success) out->clear();
} }
...@@ -45,6 +45,10 @@ class CreatePyReaderOp : public framework::OperatorBase { ...@@ -45,6 +45,10 @@ class CreatePyReaderOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) return;
const std::string& queue_name = Input("blocking_queue"); const std::string& queue_name = Input("blocking_queue");
auto* queue_holder_var = scope.FindVar(queue_name); auto* queue_holder_var = scope.FindVar(queue_name);
PADDLE_ENFORCE( PADDLE_ENFORCE(
...@@ -53,8 +57,7 @@ class CreatePyReaderOp : public framework::OperatorBase { ...@@ -53,8 +57,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
queue_name); queue_name);
auto* queue_holder = auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>(); queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new PyReader(queue_holder->GetQueue())); out->Reset(new PyReader(queue_holder->GetQueue()));
} }
}; };
......
...@@ -34,36 +34,33 @@ class LoDTensorBlockingQueue { ...@@ -34,36 +34,33 @@ class LoDTensorBlockingQueue {
private: private:
LoDTensorBlockingQueue(size_t capacity, LoDTensorBlockingQueue(size_t capacity,
const std::vector<framework::DDim>& dims) const std::vector<framework::DDim>& dims)
: dims_(dims) { : queue_(capacity), dims_(dims) {}
queue_.reset(
new BlockingQueue<std::vector<framework::LoDTensor>>(capacity));
}
public: public:
bool Enqueue(const std::vector<framework::LoDTensor>& lod_tensor_vec) { bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
CheckDims(lod_tensor_vec); CheckDims(lod_tensor_vec);
return queue_->Send(lod_tensor_vec); return queue_.Send(lod_tensor_vec);
} }
bool Enqueue(std::vector<framework::LoDTensor>&& lod_tensor_vec) { bool Push(std::vector<framework::LoDTensor>&& lod_tensor_vec) {
CheckDims(lod_tensor_vec); CheckDims(lod_tensor_vec);
return queue_->Send(std::move(lod_tensor_vec)); return queue_.Send(std::move(lod_tensor_vec));
} }
std::vector<framework::LoDTensor> Dequeue(bool* ok = nullptr) { std::vector<framework::LoDTensor> Pop(bool* ok = nullptr) {
std::vector<framework::LoDTensor> lod_tensor_vec; std::vector<framework::LoDTensor> lod_tensor_vec;
bool success = queue_->Receive(&lod_tensor_vec); bool success = queue_.Receive(&lod_tensor_vec);
if (ok != nullptr) *ok = success; if (ok != nullptr) *ok = success;
return lod_tensor_vec; return lod_tensor_vec;
} }
inline size_t Cap() const { return queue_->Cap(); } inline size_t Cap() const { return queue_.Cap(); }
inline size_t Size() const { return queue_->Size(); } inline size_t Size() const { return queue_.Size(); }
inline void Close() { return queue_->Close(); } inline void Close() { return queue_.Close(); }
inline bool IsClosed() const { return queue_->IsClosed(); } inline bool IsClosed() const { return queue_.IsClosed(); }
private: private:
void CheckDims(const std::vector<framework::LoDTensor>& lod_tensor_vec) { void CheckDims(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
...@@ -71,15 +68,16 @@ class LoDTensorBlockingQueue { ...@@ -71,15 +68,16 @@ class LoDTensorBlockingQueue {
"Expect input size is %d but found %s", dims_.size(), "Expect input size is %d but found %s", dims_.size(),
lod_tensor_vec.size()); lod_tensor_vec.size());
for (size_t i = 0; i < dims_.size(); ++i) { for (size_t i = 0; i < dims_.size(); ++i) {
const auto& in_dims = lod_tensor_vec[i].dims(); const auto& in_dims = framework::slice_ddim(
lod_tensor_vec[i].dims(), 1, lod_tensor_vec[i].dims().size());
const auto& expect_dims = const auto& expect_dims =
framework::slice_ddim(dims_[i], 1, dims_[i].size()); framework::slice_ddim(dims_[i], 1, dims_[i].size());
PADDLE_ENFORCE(in_dims == expect_dims, PADDLE_ENFORCE(in_dims == expect_dims,
"Dims of the %d-th input tensor does not match", i); "Dims of the %d-th input tensor do not match", i);
} }
} }
std::unique_ptr<BlockingQueue<std::vector<framework::LoDTensor>>> queue_; BlockingQueue<std::vector<framework::LoDTensor>> queue_;
std::vector<framework::DDim> dims_; std::vector<framework::DDim> dims_;
}; };
...@@ -92,8 +90,6 @@ class LoDTensorBlockingQueueHolder { ...@@ -92,8 +90,6 @@ class LoDTensorBlockingQueueHolder {
queue_.reset(new LoDTensorBlockingQueue(capacity, dims)); queue_.reset(new LoDTensorBlockingQueue(capacity, dims));
} }
inline std::shared_ptr<LoDTensorBlockingQueue> GetQueue() { return queue_; }
inline const std::shared_ptr<LoDTensorBlockingQueue>& GetQueue() const { inline const std::shared_ptr<LoDTensorBlockingQueue>& GetQueue() const {
return queue_; return queue_;
} }
......
...@@ -303,19 +303,16 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -303,19 +303,16 @@ All parameter, weight, gradient are variables in Paddle.
using LoDTensorBlockingQueueHolder = using LoDTensorBlockingQueueHolder =
::paddle::operators::reader::LoDTensorBlockingQueueHolder; ::paddle::operators::reader::LoDTensorBlockingQueueHolder;
py::class_<LoDTensorBlockingQueue>(m, "LoDTensorBlockingQueue", "") py::class_<LoDTensorBlockingQueue>(m, "LoDTensorBlockingQueue", "")
.def("enqueue", .def("push",
[](LoDTensorBlockingQueue &self, [](LoDTensorBlockingQueue &self,
const std::vector<framework::LoDTensor> &lod_tensor_vec) { const std::vector<framework::LoDTensor> &lod_tensor_vec) {
pybind11::gil_scoped_release release; pybind11::gil_scoped_release release;
return self.Enqueue(lod_tensor_vec); return self.Push(lod_tensor_vec);
}) })
.def("size", .def("size", &LoDTensorBlockingQueue::Size)
[](const LoDTensorBlockingQueue &self) { return self.Size(); }) .def("capacity", &LoDTensorBlockingQueue::Cap)
.def("capacity", .def("close", &LoDTensorBlockingQueue::Close)
[](const LoDTensorBlockingQueue &self) { return self.Cap(); }) .def("is_closed", &LoDTensorBlockingQueue::IsClosed);
.def("close", [](LoDTensorBlockingQueue &self) { return self.Close(); })
.def("is_closed",
[](const LoDTensorBlockingQueue &self) { return self.IsClosed(); });
m.def("init_lod_tensor_blocking_queue", m.def("init_lod_tensor_blocking_queue",
[](Variable &var, size_t capacity, [](Variable &var, size_t capacity,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册