未验证 提交 d082f3a9 编写于 作者: Y Yi Wang 提交者: GitHub

Rewrite class Channel to implement buffered and unbuffered channels (#7915)

* Remove IsBounded as buffered channels have to be bounded

* Add derived classes Buffered and UnBuffered"

* Implement buffered and unbuffered channels

* Correct the syntax of Channel::Receive

* clang-format

* clang-format 3.8

* clang 3.8
上级 0311fd15
...@@ -98,3 +98,5 @@ if(NOT WITH_C_API AND WITH_FLUID) ...@@ -98,3 +98,5 @@ if(NOT WITH_C_API AND WITH_FLUID)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/framework.pb.h DESTINATION include/paddle/framework) install(FILES ${CMAKE_CURRENT_BINARY_DIR}/framework.pb.h DESTINATION include/paddle/framework)
install(FILES details/cow_ptr.h details/op_registry.h DESTINATION include/paddle/framework/details) install(FILES details/cow_ptr.h details/op_registry.h DESTINATION include/paddle/framework/details)
endif() endif()
cc_test(channel_test SRCS channel_test.cc)
...@@ -13,75 +13,52 @@ See the License for the specific language governing permissions and ...@@ -13,75 +13,52 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <condition_variable>
#include <mutex> #include <stddef.h> // for size_t
#include <queue>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Channel is the abstract class of buffered and un-buffered channels.
template <typename T> template <typename T>
class Channel { class Channel {
public: public:
explicit Channel(std::size_t capacity) : capacity_(capacity) {} virtual void Send(T*) = 0;
virtual void Receive(T*) = 0;
void Send(T* channel_element) { virtual size_t Cap() = 0;
std::unique_lock<std::mutex> lock(mu_);
if (IsBounded()) {
full_cond_var_.wait(lock, [this]() {
bool capacity_valid = capacity_ > 0 ? !IsCapacityFull() : true;
return capacity_valid;
});
}
channel_.push_back(std::move(*channel_element));
lock.unlock();
empty_cond_var_.notify_one();
}
T* Receive() { // Don't delete channels; instead, call Channel::Close.
std::unique_lock<std::mutex> lock(mu_); protected:
empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); }); virtual ~Channel() {}
};
T* channel_element = std::move(channel_.front());
channel_.pop_front();
NotifyAllSenders(&lock);
return channel_element;
}
size_t Size() {
std::unique_lock<std::mutex> lock(mu_);
return channel_.size();
}
void Clear() { // Forward declaration of channel implementations.
std::unique_lock<std::mutex> lock(mu_); namespace details {
channel_.clear(); template <typename T>
class Buffered;
template <typename T>
class UnBuffered;
} // namespace details
NotifyAllSenders(&lock); template <typename T>
Channel<T>* MakeChannel(size_t buffer_size) {
if (buffer_size > 0) {
return new details::Buffered<T>(buffer_size);
} }
return new details::UnBuffered<T>();
}
private: template <typename T>
std::size_t capacity_; void CloseChannel(Channel<T>* ch) {
std::mutex mu_; if (ch->Cap() > 0) {
std::condition_variable empty_cond_var_; delete dynamic_cast<details::Buffered<T>*>(ch);
std::condition_variable full_cond_var_; } else {
std::deque<T> channel_; delete dynamic_cast<details::UnBuffered<T>*>(ch);
private:
void NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
if (IsBounded()) {
lock->unlock();
full_cond_var_.notify_one();
}
} }
}
bool IsBounded() const { return capacity_ > 0; } } // namespace framework
bool IsCapacityFull() const { return channel_.size() >= capacity_; }
};
} // namespace operator
} // namespace paddle } // namespace paddle
#include "paddle/framework/details/buffered_channel.h"
#include "paddle/framework/details/unbuffered_channel.h"
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/channel.h"
#include "gtest/gtest.h"
TEST(Channel, MakeAndClose) {
using paddle::framework::Channel;
using paddle::framework::MakeChannel;
using paddle::framework::CloseChannel;
Channel<int>* ch = MakeChannel<int>(10);
CloseChannel(ch);
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <condition_variable>
#include <deque>
#include <mutex>
#include "paddle/framework/channel.h"
namespace paddle {
namespace framework {
namespace details {
template <typename T>
class Buffered : public paddle::framework::Channel<T> {
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public:
virtual void Send(T*);
virtual void Receive(T*);
virtual size_t Cap() { return cap_; }
private:
size_t cap_;
std::mutex mu_;
std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_;
std::deque<T> channel_;
Buffered(size_t cap) : cap_(cap) {}
virtual ~Buffered();
void NotifyAllSenders(std::unique_lock<std::mutex>*);
};
template <typename T>
void Buffered<T>::Send(T* item) {
std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_; });
channel_.push_back(std::move(*item));
lock.unlock();
empty_cond_var_.notify_one();
}
template <typename T>
void Buffered<T>::Receive(T* item) {
std::unique_lock<std::mutex> lock(mu_);
empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); });
*item = std::move(channel_.front());
channel_.pop_front();
NotifyAllSenders(&lock);
}
template <typename T>
Buffered<T>::~Buffered() {
std::unique_lock<std::mutex> lock(mu_);
channel_.clear();
NotifyAllSenders(&lock);
}
template <typename T>
void Buffered<T>::NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
lock->unlock();
full_cond_var_.notify_one();
}
} // namespace details
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <condition_variable>
#include <deque>
#include <mutex>
#include "paddle/framework/channel.h"
namespace paddle {
namespace framework {
namespace details {
template <typename T>
class UnBuffered : public paddle::framework::Channel<T> {
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public:
virtual void Send(T*);
virtual void Receive(T*);
virtual size_t Cap() { return 0; }
private:
UnBuffered() {}
virtual ~UnBuffered();
};
template <typename T>
void UnBuffered<T>::Send(T* channel_element) {}
template <typename T>
void UnBuffered<T>::Receive(T*) {}
template <typename T>
UnBuffered<T>::~UnBuffered() {}
} // namespace details
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册