diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 280496984251919a8b4b6c52684f950a80b78356..318661af8bd04880577222fdc82cc1b6e79a457f 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -98,3 +98,5 @@ if(NOT WITH_C_API AND WITH_FLUID) install(FILES ${CMAKE_CURRENT_BINARY_DIR}/framework.pb.h DESTINATION include/paddle/framework) install(FILES details/cow_ptr.h details/op_registry.h DESTINATION include/paddle/framework/details) endif() + +cc_test(channel_test SRCS channel_test.cc) diff --git a/paddle/framework/channel.h b/paddle/framework/channel.h index 9ba0fc5c558a85b41deb01ad57842d9c4c054e0e..70ecccc1a1078374f3190b3956103ed8000c4fc5 100644 --- a/paddle/framework/channel.h +++ b/paddle/framework/channel.h @@ -13,75 +13,52 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include -#include -#include + +#include // for size_t namespace paddle { namespace framework { +// Channel is the abstract class of buffered and un-buffered channels. template class Channel { public: - explicit Channel(std::size_t capacity) : capacity_(capacity) {} - - void Send(T* channel_element) { - std::unique_lock lock(mu_); - - if (IsBounded()) { - full_cond_var_.wait(lock, [this]() { - bool capacity_valid = capacity_ > 0 ? !IsCapacityFull() : true; - return capacity_valid; - }); - } - channel_.push_back(std::move(*channel_element)); - - lock.unlock(); - empty_cond_var_.notify_one(); - } + virtual void Send(T*) = 0; + virtual void Receive(T*) = 0; + virtual size_t Cap() = 0; - T* Receive() { - std::unique_lock lock(mu_); - empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); }); - - T* channel_element = std::move(channel_.front()); - channel_.pop_front(); - - NotifyAllSenders(&lock); - return channel_element; - } - - size_t Size() { - std::unique_lock lock(mu_); - return channel_.size(); - } + // Don't delete channels; instead, call Channel::Close. + protected: + virtual ~Channel() {} +}; - void Clear() { - std::unique_lock lock(mu_); - channel_.clear(); +// Forward declaration of channel implementations. +namespace details { +template +class Buffered; +template +class UnBuffered; +} // namespace details - NotifyAllSenders(&lock); +template +Channel* MakeChannel(size_t buffer_size) { + if (buffer_size > 0) { + return new details::Buffered(buffer_size); } + return new details::UnBuffered(); +} - private: - std::size_t capacity_; - std::mutex mu_; - std::condition_variable empty_cond_var_; - std::condition_variable full_cond_var_; - std::deque channel_; - - private: - void NotifyAllSenders(std::unique_lock* lock) { - if (IsBounded()) { - lock->unlock(); - full_cond_var_.notify_one(); - } +template +void CloseChannel(Channel* ch) { + if (ch->Cap() > 0) { + delete dynamic_cast*>(ch); + } else { + delete dynamic_cast*>(ch); } +} - bool IsBounded() const { return capacity_ > 0; } - - bool IsCapacityFull() const { return channel_.size() >= capacity_; } -}; - -} // namespace operator +} // namespace framework } // namespace paddle + +#include "paddle/framework/details/buffered_channel.h" +#include "paddle/framework/details/unbuffered_channel.h" diff --git a/paddle/framework/channel_test.cc b/paddle/framework/channel_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9efc0172658c800d14102531332dbef68fa392f4 --- /dev/null +++ b/paddle/framework/channel_test.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/channel.h" + +#include "gtest/gtest.h" + +TEST(Channel, MakeAndClose) { + using paddle::framework::Channel; + using paddle::framework::MakeChannel; + using paddle::framework::CloseChannel; + + Channel* ch = MakeChannel(10); + CloseChannel(ch); +} diff --git a/paddle/framework/details/buffered_channel.h b/paddle/framework/details/buffered_channel.h new file mode 100644 index 0000000000000000000000000000000000000000..572e29d44a3baec84a029d87f9b0874784aa761b --- /dev/null +++ b/paddle/framework/details/buffered_channel.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +#include "paddle/framework/channel.h" + +namespace paddle { +namespace framework { +namespace details { + +template +class Buffered : public paddle::framework::Channel { + friend Channel* paddle::framework::MakeChannel(size_t); + friend void paddle::framework::CloseChannel(Channel*); + + public: + virtual void Send(T*); + virtual void Receive(T*); + virtual size_t Cap() { return cap_; } + + private: + size_t cap_; + std::mutex mu_; + std::condition_variable empty_cond_var_; + std::condition_variable full_cond_var_; + std::deque channel_; + + Buffered(size_t cap) : cap_(cap) {} + virtual ~Buffered(); + + void NotifyAllSenders(std::unique_lock*); +}; + +template +void Buffered::Send(T* item) { + std::unique_lock lock(mu_); + full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_; }); + channel_.push_back(std::move(*item)); + lock.unlock(); + empty_cond_var_.notify_one(); +} + +template +void Buffered::Receive(T* item) { + std::unique_lock lock(mu_); + empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); }); + *item = std::move(channel_.front()); + channel_.pop_front(); + NotifyAllSenders(&lock); +} + +template +Buffered::~Buffered() { + std::unique_lock lock(mu_); + channel_.clear(); + NotifyAllSenders(&lock); +} + +template +void Buffered::NotifyAllSenders(std::unique_lock* lock) { + lock->unlock(); + full_cond_var_.notify_one(); +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/details/unbuffered_channel.h b/paddle/framework/details/unbuffered_channel.h new file mode 100644 index 0000000000000000000000000000000000000000..7ecced1fba88fea781fc342091bc71e5aa496d3a --- /dev/null +++ b/paddle/framework/details/unbuffered_channel.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +#include "paddle/framework/channel.h" + +namespace paddle { +namespace framework { +namespace details { + +template +class UnBuffered : public paddle::framework::Channel { + friend Channel* paddle::framework::MakeChannel(size_t); + friend void paddle::framework::CloseChannel(Channel*); + + public: + virtual void Send(T*); + virtual void Receive(T*); + virtual size_t Cap() { return 0; } + + private: + UnBuffered() {} + virtual ~UnBuffered(); +}; + +template +void UnBuffered::Send(T* channel_element) {} + +template +void UnBuffered::Receive(T*) {} + +template +UnBuffered::~UnBuffered() {} + +} // namespace details +} // namespace framework +} // namespace paddle