Commit d4d946db authored by sneaxiy

update blocking queue

Parent commit: 67556e4a
......@@ -28,7 +28,7 @@ class PyReader : public framework::ReaderBase {
// Read the next batch from the underlying blocking queue.
// On success `*out` holds one LoDTensor per reader variable; once the
// queue is closed and drained, Pop reports failure and `*out` is
// cleared so callers see an empty batch as the end-of-data signal.
void ReadNext(std::vector<framework::LoDTensor>* out) override {
  bool success;
  // NOTE(review): the diff rename was Dequeue -> Pop; the stale
  // pre-rename call is removed here so only one assignment remains.
  *out = queue_->Pop(&success);
  if (!success) out->clear();
}
......@@ -45,6 +45,10 @@ class CreatePyReaderOp : public framework::OperatorBase {
private:
// Creates (once) the PyReader that feeds from a Python-side blocking queue.
// The reader is stored in the "Out" variable; if it is already set, the op
// is a no-op (early return below), making repeated runs idempotent.
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
// Already initialized on a previous run — nothing to do.
if (out->Get() != nullptr) return;
const std::string& queue_name = Input("blocking_queue");
auto* queue_holder_var = scope.FindVar(queue_name);
// NOTE(review): the enforce message is truncated here by a diff hunk
// marker — the original error text between these lines is not visible
// in this capture.
PADDLE_ENFORCE(
......@@ -53,8 +57,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
queue_name);
auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
// NOTE(review): this second `out` lookup duplicates the one at the top
// of the function — it is the pre-change position of the lookup left
// over from the collapsed diff; confirm against the upstream commit.
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new PyReader(queue_holder->GetQueue()));
}
};
......
......@@ -34,36 +34,33 @@ class LoDTensorBlockingQueue {
private:
// Construct a queue with a fixed capacity and the expected dims of each
// input variable (used by CheckDims() to validate every pushed batch).
// The queue is now held by value, so it is initialized directly in the
// member-init list instead of being heap-allocated via unique_ptr.
LoDTensorBlockingQueue(size_t capacity,
                       const std::vector<framework::DDim>& dims)
    : queue_(capacity), dims_(dims) {}
public:
// Validate the batch against the expected dims, then enqueue a copy.
// Returns false when the queue has been closed (Send fails).
bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
  CheckDims(lod_tensor_vec);
  return queue_.Send(lod_tensor_vec);
}
// Rvalue overload: validate, then move the batch into the queue to
// avoid copying the tensors. Returns false when the queue is closed.
bool Push(std::vector<framework::LoDTensor>&& lod_tensor_vec) {
  CheckDims(lod_tensor_vec);
  return queue_.Send(std::move(lod_tensor_vec));
}
std::vector<framework::LoDTensor> Dequeue(bool* ok = nullptr) {
std::vector<framework::LoDTensor> Pop(bool* ok = nullptr) {
std::vector<framework::LoDTensor> lod_tensor_vec;
bool success = queue_->Receive(&lod_tensor_vec);
bool success = queue_.Receive(&lod_tensor_vec);
if (ok != nullptr) *ok = success;
return lod_tensor_vec;
}
// Thin forwarding accessors over the by-value member queue
// (previously `queue_->` through a unique_ptr).
inline size_t Cap() const { return queue_.Cap(); }

inline size_t Size() const { return queue_.Size(); }

inline void Close() { return queue_.Close(); }

inline bool IsClosed() const { return queue_.IsClosed(); }
private:
// Validate that a pushed batch has one tensor per expected variable and
// that each tensor's trailing dims match the expected dims (the leading
// batch dimension is sliced off on both sides before comparing).
void CheckDims(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
// NOTE(review): the opening of the size-check PADDLE_ENFORCE is hidden
// behind this diff hunk marker; only its message/arguments are visible.
......@@ -71,15 +68,16 @@ class LoDTensorBlockingQueue {
"Expect input size is %d but found %s", dims_.size(),
lod_tensor_vec.size());
for (size_t i = 0; i < dims_.size(); ++i) {
// NOTE(review): the next two initializations of `in_dims` are the old
// and new sides of the diff — the change slices off dim 0 (batch size)
// before the comparison; confirm against the upstream commit.
const auto& in_dims = lod_tensor_vec[i].dims();
const auto& in_dims = framework::slice_ddim(
lod_tensor_vec[i].dims(), 1, lod_tensor_vec[i].dims().size());
const auto& expect_dims =
framework::slice_ddim(dims_[i], 1, dims_[i].size());
// Old and new error-message wordings both appear below (diff artifact).
PADDLE_ENFORCE(in_dims == expect_dims,
"Dims of the %d-th input tensor does not match", i);
"Dims of the %d-th input tensor do not match", i);
}
}
// Queue stored by value; the previous unique_ptr indirection was removed
// (one fewer heap allocation, and the ctor now uses member-init).
BlockingQueue<std::vector<framework::LoDTensor>> queue_;
// Expected per-variable dims, checked on every Push via CheckDims().
std::vector<framework::DDim> dims_;
};
......@@ -92,8 +90,6 @@ class LoDTensorBlockingQueueHolder {
// NOTE(review): the start of this holder-init function lies above the
// visible range; only the tail of its body is shown here.
queue_.reset(new LoDTensorBlockingQueue(capacity, dims));
}

// Accessors returning the shared queue. The two forms below read as a
// legal non-const / const overload pair, but this capture is a collapsed
// diff — one of them may have been removed upstream; confirm against the
// original commit.
inline std::shared_ptr<LoDTensorBlockingQueue> GetQueue() { return queue_; }

inline const std::shared_ptr<LoDTensorBlockingQueue>& GetQueue() const {
return queue_;
}
......
......@@ -303,19 +303,16 @@ All parameter, weight, gradient are variables in Paddle.
using LoDTensorBlockingQueueHolder =
    ::paddle::operators::reader::LoDTensorBlockingQueueHolder;

// Python bindings for the blocking queue. Method names follow the C++
// rename (enqueue/dequeue -> push/pop). `push` releases the GIL because
// Send may block until capacity is available; the remaining methods are
// bound directly to the member functions (no lambda wrappers needed).
py::class_<LoDTensorBlockingQueue>(m, "LoDTensorBlockingQueue", "")
    .def("push",
         [](LoDTensorBlockingQueue &self,
            const std::vector<framework::LoDTensor> &lod_tensor_vec) {
           // Release the GIL so Python threads can run while Push blocks.
           pybind11::gil_scoped_release release;
           return self.Push(lod_tensor_vec);
         })
    .def("size", &LoDTensorBlockingQueue::Size)
    .def("capacity", &LoDTensorBlockingQueue::Cap)
    .def("close", &LoDTensorBlockingQueue::Close)
    .def("is_closed", &LoDTensorBlockingQueue::IsClosed);
m.def("init_lod_tensor_blocking_queue",
[](Variable &var, size_t capacity,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.