提交 251e4a8e 编写于 作者: T typhoonzero

unify fluid blocking queue

上级 6c0356e4
......@@ -17,36 +17,58 @@ limitations under the License. */
#include <condition_variable> // NOLINT
#include <deque>
#include <mutex> // NOLINT
#include <utility>
namespace paddle {
namespace operators {
namespace detail {
namespace framework {
template <typename T>
class SimpleBlockQueue {
private:
std::mutex mutex_;
std::condition_variable condition_;
std::deque<T> queue_;
class BlockingQueue {
public:
void Push(T const& value) {
void Push(const T &item) {
{
std::lock_guard<std::mutex> g(mutex_);
q_.emplace_back(item);
}
cv_.notify_one();
}
template <typename U>
void Extend(const U &items) {
{
std::unique_lock<std::mutex> lock(this->mutex_);
queue_.push_front(value);
std::lock_guard<std::mutex> g(mutex_);
for (auto &item : items) {
q_.emplace_back(item);
}
this->condition_.notify_one();
}
cv_.notify_all();
}
std::deque<T> PopAll(size_t ms, bool *timeout) {
auto time =
std::chrono::system_clock::now() + std::chrono::milliseconds(ms);
std::unique_lock<std::mutex> lock(mutex_);
*timeout = !cv_.wait_until(lock, time, [this] { return !q_.empty(); });
std::deque<T> ret;
if (!*timeout) {
std::swap(ret, q_);
}
return ret;
}
T Pop() {
std::unique_lock<std::mutex> lock(this->mutex_);
this->condition_.wait(lock, [=] { return !this->queue_.empty(); });
T rc(std::move(this->queue_.back()));
this->queue_.pop_back();
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !q_.empty(); });
T rc(std::move(q_.front()));
q_.pop_front();
return rc;
}
private:
std::mutex mutex_;
std::condition_variable cv_;
std::deque<T> q_;
};
} // namespace detail
} // namespace operators
} // namespace framework
} // namespace paddle
......@@ -22,6 +22,7 @@
#include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
namespace paddle {
......@@ -30,46 +31,6 @@ class Scope;
namespace details {
template <typename T>
class BlockingQueue {
public:
void Push(const T &item) {
{
std::lock_guard<std::mutex> g(mutex_);
q_.emplace_back(item);
}
cv_.notify_one();
}
template <typename U>
void Extend(const U &items) {
{
std::lock_guard<std::mutex> g(mutex_);
for (auto &item : items) {
q_.emplace_back(item);
}
}
cv_.notify_all();
}
std::deque<T> PopAll(size_t ms, bool *timeout) {
auto time =
std::chrono::system_clock::now() + std::chrono::milliseconds(ms);
std::unique_lock<std::mutex> lock(mutex_);
*timeout = !cv_.wait_until(lock, time, [this] { return !q_.empty(); });
std::deque<T> ret;
if (!*timeout) {
std::swap(ret, q_);
}
return ret;
}
private:
std::mutex mutex_;
std::condition_variable cv_;
std::deque<T> q_;
};
class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
public:
ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
......
......@@ -29,12 +29,12 @@ limitations under the License. */
#include "grpc++/support/byte_buffer.h"
#include "grpc++/support/slice.h"
#include "grpc/support/log.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
namespace paddle {
namespace operators {
......
......@@ -90,7 +90,7 @@ class RequestGet final : public RequestBase {
::grpc::ServerCompletionQueue* cq,
framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
SimpleBlockQueue<MessageWithName>* queue)
framework::BlockingQueue<MessageWithName>* queue)
: RequestBase(service, cq, dev_ctx),
responder_(&ctx_),
scope_(scope),
......@@ -128,7 +128,7 @@ class RequestGet final : public RequestBase {
sendrecv::VariableMessage request_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_;
SimpleBlockQueue<MessageWithName>* queue_;
framework::BlockingQueue<MessageWithName>* queue_;
};
class RequestPrefetch final : public RequestBase {
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <utility>
#include "grpc++/grpc++.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -29,7 +30,6 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
namespace paddle {
namespace operators {
......@@ -37,7 +37,7 @@ namespace detail {
typedef std::pair<std::string, std::shared_ptr<VariableResponse>>
ReceivedMessage;
typedef SimpleBlockQueue<ReceivedMessage> ReceivedQueue;
typedef framework::BlockingQueue<ReceivedMessage> ReceivedQueue;
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class RequestBase;
......@@ -99,7 +99,7 @@ class AsyncGRPCServer final {
const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_get_queue_;
framework::BlockingQueue<MessageWithName> var_get_queue_;
// client send variable to this queue.
ReceivedQueue var_recv_queue_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册