diff --git a/paddle/framework/channel.h b/paddle/framework/channel.h index 70ecccc1a1078374f3190b3956103ed8000c4fc5..0570980c5a4d7fa45e672ae5baac65d2c65ddad9 100644 --- a/paddle/framework/channel.h +++ b/paddle/framework/channel.h @@ -26,9 +26,7 @@ class Channel { virtual void Send(T*) = 0; virtual void Receive(T*) = 0; virtual size_t Cap() = 0; - - // Don't delete channels; instead, call Channel::Close. - protected: + virtual void Close() = 0; virtual ~Channel() {} }; @@ -50,11 +48,7 @@ Channel* MakeChannel(size_t buffer_size) { template void CloseChannel(Channel* ch) { - if (ch->Cap() > 0) { - delete dynamic_cast*>(ch); - } else { - delete dynamic_cast*>(ch); - } + ch->Close(); } } // namespace framework diff --git a/paddle/framework/channel_test.cc b/paddle/framework/channel_test.cc index 9efc0172658c800d14102531332dbef68fa392f4..1510fb8abf54f05804bd404d9bd00ecc42fbef63 100644 --- a/paddle/framework/channel_test.cc +++ b/paddle/framework/channel_test.cc @@ -14,13 +14,67 @@ limitations under the License. */ #include "paddle/framework/channel.h" +#include +#include + #include "gtest/gtest.h" +using paddle::framework::Channel; +using paddle::framework::MakeChannel; +using paddle::framework::CloseChannel; + TEST(Channel, MakeAndClose) { - using paddle::framework::Channel; - using paddle::framework::MakeChannel; - using paddle::framework::CloseChannel; + using paddle::framework::details::Buffered; + using paddle::framework::details::UnBuffered; + { + // MakeChannel should return a buffered channel is buffer_size > 0. + auto ch = MakeChannel(10); + EXPECT_NE(dynamic_cast*>(ch), nullptr); + EXPECT_EQ(dynamic_cast*>(ch), nullptr); + CloseChannel(ch); + delete ch; + } + { + // MakeChannel should return an un-buffered channel is buffer_size = 0. + auto ch = MakeChannel(0); + EXPECT_EQ(dynamic_cast*>(ch), nullptr); + EXPECT_NE(dynamic_cast*>(ch), nullptr); + CloseChannel(ch); + delete ch; + } +} + +TEST(Channel, SufficientBufferSizeDoesntBlock) { + const size_t buffer_size = 10; + auto ch = MakeChannel(buffer_size); + for (size_t i = 0; i < buffer_size; ++i) { + ch->Send(&i); // should not block + } + + size_t out; + for (size_t i = 0; i < buffer_size; ++i) { + ch->Receive(&out); // should not block + EXPECT_EQ(out, i); + } + CloseChannel(ch); + delete ch; +} + +TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { + const size_t buffer_size = 10; + auto ch = MakeChannel(buffer_size); + size_t sum = 0; + std::thread t([&]() { + // Try to write more than buffer size. + for (size_t i = 0; i < 2 * buffer_size; ++i) { + ch->Send(&i); // should not block + sum += i; + } + }); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.5 sec + EXPECT_EQ(sum, 45U); - Channel* ch = MakeChannel(10); CloseChannel(ch); + t.join(); + delete ch; } diff --git a/paddle/framework/details/buffered_channel.h b/paddle/framework/details/buffered_channel.h index 572e29d44a3baec84a029d87f9b0874784aa761b..b093e1589293b030ef2bedb82504a8e86b3dc857 100644 --- a/paddle/framework/details/buffered_channel.h +++ b/paddle/framework/details/buffered_channel.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/framework/channel.h" +#include "paddle/platform/enforce.h" namespace paddle { namespace framework { @@ -32,6 +33,8 @@ class Buffered : public paddle::framework::Channel { virtual void Send(T*); virtual void Receive(T*); virtual size_t Cap() { return cap_; } + virtual void Close(); + virtual ~Buffered(); private: size_t cap_; @@ -39,9 +42,11 @@ class Buffered : public paddle::framework::Channel { std::condition_variable empty_cond_var_; std::condition_variable full_cond_var_; std::deque channel_; + bool closed_; - Buffered(size_t cap) : cap_(cap) {} - virtual ~Buffered(); + Buffered(size_t cap) : cap_(cap), closed_(false) { + PADDLE_ENFORCE_GT(cap, 0); + } void NotifyAllSenders(std::unique_lock*); }; @@ -49,24 +54,39 @@ class Buffered : public paddle::framework::Channel { template void Buffered::Send(T* item) { std::unique_lock lock(mu_); - full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_; }); - channel_.push_back(std::move(*item)); - lock.unlock(); - empty_cond_var_.notify_one(); + full_cond_var_.wait(lock, + [this]() { return channel_.size() < cap_ || closed_; }); + if (!closed_) { + channel_.push_back(std::move(*item)); + lock.unlock(); + empty_cond_var_.notify_one(); + } } template void Buffered::Receive(T* item) { std::unique_lock lock(mu_); - empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); }); - *item = std::move(channel_.front()); - channel_.pop_front(); + empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; }); + if (!closed_) { + *item = std::move(channel_.front()); + channel_.pop_front(); + NotifyAllSenders(&lock); + } else { + item = nullptr; + } +} + +template +void Buffered::Close() { + std::unique_lock lock(mu_); + closed_ = true; NotifyAllSenders(&lock); } template Buffered::~Buffered() { std::unique_lock lock(mu_); + closed_ = true; channel_.clear(); NotifyAllSenders(&lock); } @@ -74,7 +94,7 @@ Buffered::~Buffered() { template void Buffered::NotifyAllSenders(std::unique_lock* lock) { lock->unlock(); - full_cond_var_.notify_one(); + full_cond_var_.notify_all(); } } // namespace details diff --git a/paddle/framework/details/unbuffered_channel.h b/paddle/framework/details/unbuffered_channel.h index 7ecced1fba88fea781fc342091bc71e5aa496d3a..cc2d2e587eca981307d4e522bd569fbffa450207 100644 --- a/paddle/framework/details/unbuffered_channel.h +++ b/paddle/framework/details/unbuffered_channel.h @@ -32,10 +32,11 @@ class UnBuffered : public paddle::framework::Channel { virtual void Send(T*); virtual void Receive(T*); virtual size_t Cap() { return 0; } + virtual void Close(); + virtual ~UnBuffered(); private: UnBuffered() {} - virtual ~UnBuffered(); }; template @@ -44,6 +45,9 @@ void UnBuffered::Send(T* channel_element) {} template void UnBuffered::Receive(T*) {} +template +void UnBuffered::Close() {} + template UnBuffered::~UnBuffered() {}