提交 adf14b0c 编写于 作者: C chengduo 提交者: Yi Wang

Refine channel test (#7946)

* refine channel test

* follow comments

* Add dependency enforce to threadpool

* Revert changes to channel_test.cc

* Revert changes to channel_test.cc

* Add #include "paddle/framework/macros.h"

* Add unit tests

* fix code format

* refine close channel

* follow comments

* use delete to destroy channel
上级 6f7eb0d5
...@@ -26,9 +26,7 @@ class Channel { ...@@ -26,9 +26,7 @@ class Channel {
virtual void Send(T*) = 0; virtual void Send(T*) = 0;
virtual void Receive(T*) = 0; virtual void Receive(T*) = 0;
virtual size_t Cap() = 0; virtual size_t Cap() = 0;
virtual void Close() = 0;
// Don't delete channels; instead, call Channel::Close.
protected:
virtual ~Channel() {} virtual ~Channel() {}
}; };
...@@ -50,11 +48,7 @@ Channel<T>* MakeChannel(size_t buffer_size) { ...@@ -50,11 +48,7 @@ Channel<T>* MakeChannel(size_t buffer_size) {
template <typename T> template <typename T>
void CloseChannel(Channel<T>* ch) { void CloseChannel(Channel<T>* ch) {
if (ch->Cap() > 0) { ch->Close();
delete dynamic_cast<details::Buffered<T>*>(ch);
} else {
delete dynamic_cast<details::UnBuffered<T>*>(ch);
}
} }
} // namespace framework } // namespace framework
......
...@@ -14,13 +14,67 @@ limitations under the License. */ ...@@ -14,13 +14,67 @@ limitations under the License. */
#include "paddle/framework/channel.h" #include "paddle/framework/channel.h"
#include <chrono>
#include <thread>
#include "gtest/gtest.h" #include "gtest/gtest.h"
using paddle::framework::Channel;
using paddle::framework::MakeChannel;
using paddle::framework::CloseChannel;
TEST(Channel, MakeAndClose) { TEST(Channel, MakeAndClose) {
using paddle::framework::Channel; using paddle::framework::details::Buffered;
using paddle::framework::MakeChannel; using paddle::framework::details::UnBuffered;
using paddle::framework::CloseChannel; {
// MakeChannel should return a buffered channel is buffer_size > 0.
auto ch = MakeChannel<int>(10);
EXPECT_NE(dynamic_cast<Buffered<int>*>(ch), nullptr);
EXPECT_EQ(dynamic_cast<UnBuffered<int>*>(ch), nullptr);
CloseChannel(ch);
delete ch;
}
{
// MakeChannel should return an un-buffered channel is buffer_size = 0.
auto ch = MakeChannel<int>(0);
EXPECT_EQ(dynamic_cast<Buffered<int>*>(ch), nullptr);
EXPECT_NE(dynamic_cast<UnBuffered<int>*>(ch), nullptr);
CloseChannel(ch);
delete ch;
}
}
TEST(Channel, SufficientBufferSizeDoesntBlock) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
for (size_t i = 0; i < buffer_size; ++i) {
ch->Send(&i); // should not block
}
size_t out;
for (size_t i = 0; i < buffer_size; ++i) {
ch->Receive(&out); // should not block
EXPECT_EQ(out, i);
}
CloseChannel(ch);
delete ch;
}
TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
size_t sum = 0;
std::thread t([&]() {
// Try to write more than buffer size.
for (size_t i = 0; i < 2 * buffer_size; ++i) {
ch->Send(&i); // should not block
sum += i;
}
});
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.5 sec
EXPECT_EQ(sum, 45U);
Channel<int>* ch = MakeChannel<int>(10);
CloseChannel(ch); CloseChannel(ch);
t.join();
delete ch;
} }
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <mutex> #include <mutex>
#include "paddle/framework/channel.h" #include "paddle/framework/channel.h"
#include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -32,6 +33,8 @@ class Buffered : public paddle::framework::Channel<T> { ...@@ -32,6 +33,8 @@ class Buffered : public paddle::framework::Channel<T> {
virtual void Send(T*); virtual void Send(T*);
virtual void Receive(T*); virtual void Receive(T*);
virtual size_t Cap() { return cap_; } virtual size_t Cap() { return cap_; }
virtual void Close();
virtual ~Buffered();
private: private:
size_t cap_; size_t cap_;
...@@ -39,9 +42,11 @@ class Buffered : public paddle::framework::Channel<T> { ...@@ -39,9 +42,11 @@ class Buffered : public paddle::framework::Channel<T> {
std::condition_variable empty_cond_var_; std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_; std::condition_variable full_cond_var_;
std::deque<T> channel_; std::deque<T> channel_;
bool closed_;
Buffered(size_t cap) : cap_(cap) {} Buffered(size_t cap) : cap_(cap), closed_(false) {
virtual ~Buffered(); PADDLE_ENFORCE_GT(cap, 0);
}
void NotifyAllSenders(std::unique_lock<std::mutex>*); void NotifyAllSenders(std::unique_lock<std::mutex>*);
}; };
...@@ -49,24 +54,39 @@ class Buffered : public paddle::framework::Channel<T> { ...@@ -49,24 +54,39 @@ class Buffered : public paddle::framework::Channel<T> {
template <typename T> template <typename T>
void Buffered<T>::Send(T* item) { void Buffered<T>::Send(T* item) {
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_; }); full_cond_var_.wait(lock,
channel_.push_back(std::move(*item)); [this]() { return channel_.size() < cap_ || closed_; });
lock.unlock(); if (!closed_) {
empty_cond_var_.notify_one(); channel_.push_back(std::move(*item));
lock.unlock();
empty_cond_var_.notify_one();
}
} }
template <typename T> template <typename T>
void Buffered<T>::Receive(T* item) { void Buffered<T>::Receive(T* item) {
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); }); empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; });
*item = std::move(channel_.front()); if (!closed_) {
channel_.pop_front(); *item = std::move(channel_.front());
channel_.pop_front();
NotifyAllSenders(&lock);
} else {
item = nullptr;
}
}
template <typename T>
void Buffered<T>::Close() {
std::unique_lock<std::mutex> lock(mu_);
closed_ = true;
NotifyAllSenders(&lock); NotifyAllSenders(&lock);
} }
template <typename T> template <typename T>
Buffered<T>::~Buffered() { Buffered<T>::~Buffered() {
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
closed_ = true;
channel_.clear(); channel_.clear();
NotifyAllSenders(&lock); NotifyAllSenders(&lock);
} }
...@@ -74,7 +94,7 @@ Buffered<T>::~Buffered() { ...@@ -74,7 +94,7 @@ Buffered<T>::~Buffered() {
template <typename T> template <typename T>
void Buffered<T>::NotifyAllSenders(std::unique_lock<std::mutex>* lock) { void Buffered<T>::NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
lock->unlock(); lock->unlock();
full_cond_var_.notify_one(); full_cond_var_.notify_all();
} }
} // namespace details } // namespace details
......
...@@ -32,10 +32,11 @@ class UnBuffered : public paddle::framework::Channel<T> { ...@@ -32,10 +32,11 @@ class UnBuffered : public paddle::framework::Channel<T> {
virtual void Send(T*); virtual void Send(T*);
virtual void Receive(T*); virtual void Receive(T*);
virtual size_t Cap() { return 0; } virtual size_t Cap() { return 0; }
virtual void Close();
virtual ~UnBuffered();
private: private:
UnBuffered() {} UnBuffered() {}
virtual ~UnBuffered();
}; };
template <typename T> template <typename T>
...@@ -44,6 +45,9 @@ void UnBuffered<T>::Send(T* channel_element) {} ...@@ -44,6 +45,9 @@ void UnBuffered<T>::Send(T* channel_element) {}
template <typename T> template <typename T>
void UnBuffered<T>::Receive(T*) {} void UnBuffered<T>::Receive(T*) {}
template <typename T>
void UnBuffered<T>::Close() {}
template <typename T> template <typename T>
UnBuffered<T>::~UnBuffered() {} UnBuffered<T>::~UnBuffered() {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册