未验证 提交 3fdfa940 编写于 作者: W Wu Yi 提交者: GitHub

Merge pull request #10135 from typhoonzero/unify_blocking_queue

Unify fluid blocking queue
...@@ -17,36 +17,58 @@ limitations under the License. */ ...@@ -17,36 +17,58 @@ limitations under the License. */
#include <condition_variable> // NOLINT #include <condition_variable> // NOLINT
#include <deque> #include <deque>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <utility>
namespace paddle { namespace paddle {
namespace operators { namespace framework {
namespace detail {
template <typename T> template <typename T>
class SimpleBlockQueue { class BlockingQueue {
private:
std::mutex mutex_;
std::condition_variable condition_;
std::deque<T> queue_;
public: public:
void Push(T const& value) { void Push(const T &item) {
{
std::lock_guard<std::mutex> g(mutex_);
q_.emplace_back(item);
}
cv_.notify_one();
}
template <typename U>
void Extend(const U &items) {
{ {
std::unique_lock<std::mutex> lock(this->mutex_); std::lock_guard<std::mutex> g(mutex_);
queue_.push_front(value); for (auto &item : items) {
q_.emplace_back(item);
}
} }
this->condition_.notify_one(); cv_.notify_all();
}
std::deque<T> PopAll(size_t ms, bool *timeout) {
auto time =
std::chrono::system_clock::now() + std::chrono::milliseconds(ms);
std::unique_lock<std::mutex> lock(mutex_);
*timeout = !cv_.wait_until(lock, time, [this] { return !q_.empty(); });
std::deque<T> ret;
if (!*timeout) {
std::swap(ret, q_);
}
return ret;
} }
T Pop() { T Pop() {
std::unique_lock<std::mutex> lock(this->mutex_); std::unique_lock<std::mutex> lock(mutex_);
this->condition_.wait(lock, [=] { return !this->queue_.empty(); }); cv_.wait(lock, [=] { return !q_.empty(); });
T rc(std::move(this->queue_.back())); T rc(std::move(q_.front()));
this->queue_.pop_back(); q_.pop_front();
return rc; return rc;
} }
private:
std::mutex mutex_;
std::condition_variable cv_;
std::deque<T> q_;
}; };
} // namespace detail } // namespace framework
} // namespace operators
} // namespace paddle } // namespace paddle
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <functional> #include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party #include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h"
namespace paddle { namespace paddle {
...@@ -30,46 +31,6 @@ class Scope; ...@@ -30,46 +31,6 @@ class Scope;
namespace details { namespace details {
template <typename T>
class BlockingQueue {
public:
void Push(const T &item) {
{
std::lock_guard<std::mutex> g(mutex_);
q_.emplace_back(item);
}
cv_.notify_one();
}
template <typename U>
void Extend(const U &items) {
{
std::lock_guard<std::mutex> g(mutex_);
for (auto &item : items) {
q_.emplace_back(item);
}
}
cv_.notify_all();
}
std::deque<T> PopAll(size_t ms, bool *timeout) {
auto time =
std::chrono::system_clock::now() + std::chrono::milliseconds(ms);
std::unique_lock<std::mutex> lock(mutex_);
*timeout = !cv_.wait_until(lock, time, [this] { return !q_.empty(); });
std::deque<T> ret;
if (!*timeout) {
std::swap(ret, q_);
}
return ret;
}
private:
std::mutex mutex_;
std::condition_variable cv_;
std::deque<T> q_;
};
class ThreadedSSAGraphExecutor : public SSAGraphExecutor { class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
public: public:
ThreadedSSAGraphExecutor(size_t num_threads, bool use_event, ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
......
...@@ -29,12 +29,12 @@ limitations under the License. */ ...@@ -29,12 +29,12 @@ limitations under the License. */
#include "grpc++/support/byte_buffer.h" #include "grpc++/support/byte_buffer.h"
#include "grpc++/support/slice.h" #include "grpc++/support/slice.h"
#include "grpc/support/log.h" #include "grpc/support/log.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -90,7 +90,7 @@ class RequestGet final : public RequestBase { ...@@ -90,7 +90,7 @@ class RequestGet final : public RequestBase {
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq,
framework::Scope* scope, framework::Scope* scope,
const platform::DeviceContext* dev_ctx, const platform::DeviceContext* dev_ctx,
SimpleBlockQueue<MessageWithName>* queue) framework::BlockingQueue<MessageWithName>* queue)
: RequestBase(service, cq, dev_ctx), : RequestBase(service, cq, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
...@@ -128,7 +128,7 @@ class RequestGet final : public RequestBase { ...@@ -128,7 +128,7 @@ class RequestGet final : public RequestBase {
sendrecv::VariableMessage request_; sendrecv::VariableMessage request_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_; framework::Scope* scope_;
SimpleBlockQueue<MessageWithName>* queue_; framework::BlockingQueue<MessageWithName>* queue_;
}; };
class RequestPrefetch final : public RequestBase { class RequestPrefetch final : public RequestBase {
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include <utility> #include <utility>
#include "grpc++/grpc++.h" #include "grpc++/grpc++.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
...@@ -29,7 +30,6 @@ limitations under the License. */ ...@@ -29,7 +30,6 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -37,7 +37,7 @@ namespace detail { ...@@ -37,7 +37,7 @@ namespace detail {
typedef std::pair<std::string, std::shared_ptr<VariableResponse>> typedef std::pair<std::string, std::shared_ptr<VariableResponse>>
ReceivedMessage; ReceivedMessage;
typedef SimpleBlockQueue<ReceivedMessage> ReceivedQueue; typedef framework::BlockingQueue<ReceivedMessage> ReceivedQueue;
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName; typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class RequestBase; class RequestBase;
...@@ -99,7 +99,7 @@ class AsyncGRPCServer final { ...@@ -99,7 +99,7 @@ class AsyncGRPCServer final {
const platform::DeviceContext *dev_ctx_; const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue. // received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_get_queue_; framework::BlockingQueue<MessageWithName> var_get_queue_;
// client send variable to this queue. // client send variable to this queue.
ReceivedQueue var_recv_queue_; ReceivedQueue var_recv_queue_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册