未验证 提交 41894da1 编写于 作者: A Abhinav Arora 提交者: GitHub

Add changes to channel that are needed for select op (#9084)

上级 a4b0e4a1
...@@ -15,23 +15,43 @@ limitations under the License. */ ...@@ -15,23 +15,43 @@ limitations under the License. */
#pragma once #pragma once
#include <stddef.h> // for size_t #include <stddef.h> // for size_t
#include <condition_variable>
#include <typeindex> #include <typeindex>
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
enum class ChannelAction {
SEND = 0,
RECEIVE = 1,
CLOSE = 2,
};
// Channel is the abstract class of buffered and un-buffered channels. // Channel is the abstract class of buffered and un-buffered channels.
template <typename T> template <typename T>
class Channel { class Channel {
public: public:
virtual bool CanSend() = 0;
virtual bool CanReceive() = 0;
virtual bool Send(T*) = 0; virtual bool Send(T*) = 0;
virtual bool Receive(T*) = 0; virtual bool Receive(T*) = 0;
virtual size_t Cap() = 0; virtual size_t Cap() = 0;
virtual void Lock() = 0; virtual void Lock() = 0;
virtual void Unlock() = 0; virtual void Unlock() = 0;
virtual bool IsClosed() = 0;
virtual void Close() = 0; virtual void Close() = 0;
virtual ~Channel() {} virtual ~Channel() {}
virtual void AddToSendQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) = 0;
virtual void AddToReceiveQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) = 0;
virtual void RemoveFromSendQ(const void* referrer) = 0;
virtual void RemoveFromReceiveQ(const void* referrer) = 0;
}; };
// Forward declaration of channel implementations. // Forward declaration of channel implementations.
...@@ -80,6 +100,27 @@ class ChannelHolder { ...@@ -80,6 +100,27 @@ class ChannelHolder {
return channel != nullptr ? channel->Receive(data) : false; return channel != nullptr ? channel->Receive(data) : false;
} }
bool IsClosed() {
if (IsInitialized()) {
return holder_->IsClosed();
}
return false;
}
bool CanSend() {
if (IsInitialized()) {
return holder_->CanSend();
}
return false;
}
bool CanReceive() {
if (IsInitialized()) {
return holder_->CanReceive();
}
return false;
}
void close() { void close() {
if (IsInitialized()) holder_->Close(); if (IsInitialized()) holder_->Close();
} }
...@@ -97,6 +138,50 @@ class ChannelHolder { ...@@ -97,6 +138,50 @@ class ChannelHolder {
if (IsInitialized()) holder_->Unlock(); if (IsInitialized()) holder_->Unlock();
} }
template <typename T>
void AddToSendQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToSendQ(referrer, data, cond, cb);
}
}
}
template <typename T>
void AddToReceiveQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToReceiveQ(referrer, data, cond, cb);
}
}
}
template <typename T>
void RemoveFromSendQ(const void* referrer) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->RemoveFromSendQ(referrer);
}
}
}
template <typename T>
void RemoveFromReceiveQ(const void* referrer) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->RemoveFromReceiveQ(referrer);
}
}
}
inline bool IsInitialized() const { return holder_ != nullptr; } inline bool IsInitialized() const { return holder_ != nullptr; }
inline const std::type_index Type() { inline const std::type_index Type() {
...@@ -113,6 +198,9 @@ class ChannelHolder { ...@@ -113,6 +198,9 @@ class ChannelHolder {
virtual ~Placeholder() {} virtual ~Placeholder() {}
virtual const std::type_index Type() const = 0; virtual const std::type_index Type() const = 0;
virtual void* Ptr() const = 0; virtual void* Ptr() const = 0;
virtual bool IsClosed() = 0;
virtual bool CanSend() = 0;
virtual bool CanReceive() = 0;
virtual void Close() = 0; virtual void Close() = 0;
virtual void Lock() = 0; virtual void Lock() = 0;
virtual void Unlock() = 0; virtual void Unlock() = 0;
...@@ -129,6 +217,27 @@ class ChannelHolder { ...@@ -129,6 +217,27 @@ class ChannelHolder {
virtual void* Ptr() const { return static_cast<void*>(channel_.get()); } virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }
virtual bool IsClosed() {
if (channel_) {
return channel_->IsClosed();
}
return false;
}
virtual bool CanSend() {
if (channel_) {
return channel_->CanSend();
}
return false;
}
virtual bool CanReceive() {
if (channel_) {
return channel_->CanReceive();
}
return false;
}
virtual void Close() { virtual void Close() {
if (channel_) channel_->Close(); if (channel_) channel_->Close();
} }
......
...@@ -29,32 +29,50 @@ class ChannelImpl : public paddle::framework::Channel<T> { ...@@ -29,32 +29,50 @@ class ChannelImpl : public paddle::framework::Channel<T> {
friend void paddle::framework::CloseChannel<T>(Channel<T> *); friend void paddle::framework::CloseChannel<T>(Channel<T> *);
public: public:
virtual bool CanSend();
virtual bool CanReceive();
virtual bool Send(T *); virtual bool Send(T *);
virtual bool Receive(T *); virtual bool Receive(T *);
virtual size_t Cap() { return cap_; } virtual size_t Cap() { return cap_; }
virtual void Lock(); virtual void Lock();
virtual void Unlock(); virtual void Unlock();
virtual bool IsClosed();
virtual void Close(); virtual void Close();
ChannelImpl(size_t); ChannelImpl(size_t);
virtual ~ChannelImpl(); virtual ~ChannelImpl();
virtual void AddToSendQ(const void *referrer, T *data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb);
virtual void AddToReceiveQ(const void *referrer, T *data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb);
virtual void RemoveFromSendQ(const void *referrer);
virtual void RemoveFromReceiveQ(const void *referrer);
private: private:
struct QueueMessage { struct QueueMessage {
T *data; T *data;
std::condition_variable_any cond; std::shared_ptr<std::condition_variable_any> cond;
bool chan_closed = false; bool chan_closed = false;
bool completed = false; bool completed = false;
const void *referrer; // TODO(thuan): figure out better way to do this
std::function<bool(ChannelAction)> callback;
QueueMessage(T *item) : data(item) {} QueueMessage(T *item)
: data(item), cond(std::make_shared<std::condition_variable_any>()) {}
QueueMessage(T *item, std::shared_ptr<std::condition_variable_any> cond)
: data(item), cond(cond) {}
void Wait(std::unique_lock<std::recursive_mutex> &lock) { void Wait(std::unique_lock<std::recursive_mutex> &lock) {
cond.wait(lock, [this]() { return completed; }); cond->wait(lock, [this]() { return completed; });
} }
void Notify() { void Notify() {
completed = true; completed = true;
cond.notify_all(); cond->notify_all();
} }
}; };
...@@ -87,6 +105,18 @@ ChannelImpl<T>::ChannelImpl(size_t capacity) ...@@ -87,6 +105,18 @@ ChannelImpl<T>::ChannelImpl(size_t capacity)
PADDLE_ENFORCE_GE(capacity, 0); PADDLE_ENFORCE_GE(capacity, 0);
} }
template <typename T>
bool ChannelImpl<T>::CanSend() {
std::lock_guard<std::recursive_mutex> lock{mu_};
return !closed_ && (!recvq.empty() || buf_.size() < cap_);
}
template <typename T>
bool ChannelImpl<T>::CanReceive() {
std::lock_guard<std::recursive_mutex> lock{mu_};
return !(closed_ && buf_.empty()) && (!sendq.empty() || buf_.size() > 0);
}
template <typename T> template <typename T>
bool ChannelImpl<T>::Send(T *item) { bool ChannelImpl<T>::Send(T *item) {
send_ctr++; send_ctr++;
...@@ -105,7 +135,24 @@ bool ChannelImpl<T>::Send(T *item) { ...@@ -105,7 +135,24 @@ bool ChannelImpl<T>::Send(T *item) {
std::shared_ptr<QueueMessage> m = recvq.front(); std::shared_ptr<QueueMessage> m = recvq.front();
recvq.pop_front(); recvq.pop_front();
// Do the data transfer // Do the data transfer
*(m->data) = std::move(*item); // We will do this data transfer if either of the following
// cases are true
// 1. callback == nullptr // This means it was a regular channel send
// 2. callback returns true
bool do_send = true;
if (m->callback != nullptr) do_send = m->callback(ChannelAction::SEND);
if (do_send)
*(m->data) = std::move(*item);
else
// We cannot do the data transfer because
// this QueueMessage was added by Select
// and some other case was executed.
// So call the Send function again.
// We do not care about notifying other
// because they would have been notified
// by the executed select case.
return Send(item);
// Wake up the blocked process and unlock // Wake up the blocked process and unlock
m->Notify(); m->Notify();
lock.unlock(); lock.unlock();
...@@ -150,7 +197,25 @@ bool ChannelImpl<T>::Receive(T *item) { ...@@ -150,7 +197,25 @@ bool ChannelImpl<T>::Receive(T *item) {
std::shared_ptr<QueueMessage> m = sendq.front(); std::shared_ptr<QueueMessage> m = sendq.front();
sendq.pop_front(); sendq.pop_front();
// Do the data transfer // Do the data transfer
*item = std::move(*(m->data)); // We will do this data transfer if either of the following
// cases are true
// 1. callback == nullptr // This means it was a regular channel send
// 2. callback returns true
bool do_receive = true;
if (m->callback != nullptr)
do_receive = m->callback(ChannelAction::RECEIVE);
if (do_receive)
*item = std::move(*(m->data));
else
// We cannot do the data transfer because
// this QueueMessage was added by Select
// and some other case was executed.
// So call the Receive function again.
// We do not care about notifying other
// because they would have been notified
// by the executed select case.
return Receive(item);
// Wake up the blocked process and unlock // Wake up the blocked process and unlock
m->Notify(); m->Notify();
lock.unlock(); lock.unlock();
...@@ -186,6 +251,12 @@ void ChannelImpl<T>::Unlock() { ...@@ -186,6 +251,12 @@ void ChannelImpl<T>::Unlock() {
mu_.unlock(); mu_.unlock();
} }
template <typename T>
bool ChannelImpl<T>::IsClosed() {
std::lock_guard<std::recursive_mutex> lock{mu_};
return closed_;
}
template <typename T> template <typename T>
void ChannelImpl<T>::Close() { void ChannelImpl<T>::Close() {
std::unique_lock<std::recursive_mutex> lock{mu_}; std::unique_lock<std::recursive_mutex> lock{mu_};
...@@ -203,6 +274,12 @@ void ChannelImpl<T>::Close() { ...@@ -203,6 +274,12 @@ void ChannelImpl<T>::Close() {
std::shared_ptr<QueueMessage> m = recvq.front(); std::shared_ptr<QueueMessage> m = recvq.front();
recvq.pop_front(); recvq.pop_front();
m->chan_closed = true; m->chan_closed = true;
// Execute callback function (if any)
if (m->callback != nullptr) {
m->callback(ChannelAction::CLOSE);
}
m->Notify(); m->Notify();
} }
...@@ -211,10 +288,72 @@ void ChannelImpl<T>::Close() { ...@@ -211,10 +288,72 @@ void ChannelImpl<T>::Close() {
std::shared_ptr<QueueMessage> m = sendq.front(); std::shared_ptr<QueueMessage> m = sendq.front();
sendq.pop_front(); sendq.pop_front();
m->chan_closed = true; m->chan_closed = true;
// Execute callback function (if any)
if (m->callback != nullptr) {
m->callback(ChannelAction::CLOSE);
}
m->Notify(); m->Notify();
} }
} }
template <typename T>
void ChannelImpl<T>::AddToSendQ(
const void *referrer, T *data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
std::lock_guard<std::recursive_mutex> lock{mu_};
auto m = std::make_shared<QueueMessage>(data, cond);
m->referrer = referrer;
m->callback = cb;
sendq.push_back(m);
}
template <typename T>
void ChannelImpl<T>::AddToReceiveQ(
const void *referrer, T *data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
std::lock_guard<std::recursive_mutex> lock{mu_};
auto m = std::make_shared<QueueMessage>(data, cond);
m->referrer = referrer;
m->callback = cb;
recvq.push_back(m);
}
template <typename T>
void ChannelImpl<T>::RemoveFromSendQ(const void *referrer) {
std::lock_guard<std::recursive_mutex> lock{mu_};
for (auto it = sendq.begin(); it != sendq.end();) {
std::shared_ptr<QueueMessage> sendMsg = (std::shared_ptr<QueueMessage>)*it;
if (sendMsg->referrer == referrer) {
it = sendq.erase(it);
send_ctr--;
} else {
++it;
}
}
}
template <typename T>
void ChannelImpl<T>::RemoveFromReceiveQ(const void *referrer) {
std::lock_guard<std::recursive_mutex> lock{mu_};
for (auto it = recvq.begin(); it != recvq.end();) {
std::shared_ptr<QueueMessage> recvMsg = (std::shared_ptr<QueueMessage>)*it;
if (recvMsg->referrer == referrer) {
it = recvq.erase(it);
recv_ctr--;
} else {
++it;
}
}
}
template <typename T> template <typename T>
ChannelImpl<T>::~ChannelImpl() { ChannelImpl<T>::~ChannelImpl() {
Close(); Close();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册