提交 f2154002 编写于 作者: K kavyasrinet 提交者: Abhinav Arora

Adding panic logic and test case (#8171)

* Adding panic logic and test case

* Change panic behavior to boolean instead of exception

* Adding atomic

* Switch to boolean

* Fix spacing

* Add to close method
上级 b9024492
...@@ -60,6 +60,16 @@ TEST(Channel, SufficientBufferSizeDoesntBlock) { ...@@ -60,6 +60,16 @@ TEST(Channel, SufficientBufferSizeDoesntBlock) {
delete ch; delete ch;
} }
TEST(Channel, SendOnClosedChannelPanics) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
size_t i = 5;
EXPECT_EQ(ch->Send(&i), true); // should not block or panic
CloseChannel(ch);
EXPECT_EQ(ch->Send(&i), false); // should panic
delete ch;
}
TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) { TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) {
const size_t buffer_size = 10; const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size); auto ch = MakeChannel<size_t>(buffer_size);
...@@ -88,7 +98,6 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) { ...@@ -88,7 +98,6 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) {
// Note: we cannot check EXPECT_EQ(out, 0), because C++ doesn't // Note: we cannot check EXPECT_EQ(out, 0), because C++ doesn't
// define zero values like Go does. // define zero values like Go does.
} }
delete ch; delete ch;
} }
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <atomic>
#include <condition_variable> #include <condition_variable>
#include <deque> #include <deque>
#include <mutex> #include <mutex>
...@@ -42,7 +43,7 @@ class Buffered : public paddle::framework::Channel<T> { ...@@ -42,7 +43,7 @@ class Buffered : public paddle::framework::Channel<T> {
std::condition_variable empty_cond_var_; std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_; std::condition_variable full_cond_var_;
std::deque<T> channel_; std::deque<T> channel_;
bool closed_; std::atomic<bool> closed_{false};
Buffered(size_t cap) : cap_(cap), closed_(false) { Buffered(size_t cap) : cap_(cap), closed_(false) {
PADDLE_ENFORCE_GT(cap, 0); PADDLE_ENFORCE_GT(cap, 0);
...@@ -53,10 +54,13 @@ class Buffered : public paddle::framework::Channel<T> { ...@@ -53,10 +54,13 @@ class Buffered : public paddle::framework::Channel<T> {
template <typename T> template <typename T>
bool Buffered<T>::Send(T* item) { bool Buffered<T>::Send(T* item) {
bool ret = false;
if (closed_) {
return ret;
}
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock, full_cond_var_.wait(lock,
[this]() { return channel_.size() < cap_ || closed_; }); [this]() { return channel_.size() < cap_ || closed_; });
bool ret = false;
if (!closed_) { if (!closed_) {
channel_.push_back(std::move(*item)); channel_.push_back(std::move(*item));
lock.unlock(); lock.unlock();
...@@ -82,6 +86,9 @@ bool Buffered<T>::Receive(T* item) { ...@@ -82,6 +86,9 @@ bool Buffered<T>::Receive(T* item) {
template <typename T> template <typename T>
void Buffered<T>::Close() { void Buffered<T>::Close() {
if (closed_) {
return;
}
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
closed_ = true; closed_ = true;
NotifyAllParticipants(&lock); NotifyAllParticipants(&lock);
......
...@@ -58,6 +58,10 @@ class UnBuffered : public paddle::framework::Channel<T> { ...@@ -58,6 +58,10 @@ class UnBuffered : public paddle::framework::Channel<T> {
// be sent from a writer to a reader. // be sent from a writer to a reader.
template <typename T> template <typename T>
bool UnBuffered<T>::Send(T* data) { bool UnBuffered<T>::Send(T* data) {
bool ret = false;
if (closed_) {
return ret;
}
// Prevent other writers from entering // Prevent other writers from entering
std::unique_lock<std::recursive_mutex> writer_lock(mu_write_); std::unique_lock<std::recursive_mutex> writer_lock(mu_write_);
writer_found_ = true; writer_found_ = true;
...@@ -66,7 +70,6 @@ bool UnBuffered<T>::Send(T* data) { ...@@ -66,7 +70,6 @@ bool UnBuffered<T>::Send(T* data) {
cv_writer_.wait(cv_lock, cv_writer_.wait(cv_lock,
[this]() { return reader_found_ == true || closed_; }); [this]() { return reader_found_ == true || closed_; });
cv_reader_.notify_one(); cv_reader_.notify_one();
bool ret = false;
if (!closed_) { if (!closed_) {
std::unique_lock<std::mutex> channel_lock(mu_ch_); std::unique_lock<std::mutex> channel_lock(mu_ch_);
item = data; item = data;
...@@ -114,6 +117,9 @@ bool UnBuffered<T>::Receive(T* data) { ...@@ -114,6 +117,9 @@ bool UnBuffered<T>::Receive(T* data) {
// that take place once the channel is closed. // that take place once the channel is closed.
template <typename T> template <typename T>
void UnBuffered<T>::Close() { void UnBuffered<T>::Close() {
if (closed_) {
return;
}
std::unique_lock<std::mutex> lock(mu_ch_); std::unique_lock<std::mutex> lock(mu_ch_);
item = nullptr; item = nullptr;
closed_ = true; closed_ = true;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册