提交 f2154002 编写于 作者: K kavyasrinet 提交者: Abhinav Arora

Adding panic logic and test case (#8171)

* Adding panic logic and test case

* Change panic behavior to boolean instead of exception

* Adding atomic

* Switch to boolean

* Fix spacing

* Add to close method
上级 b9024492
......@@ -60,6 +60,16 @@ TEST(Channel, SufficientBufferSizeDoesntBlock) {
delete ch;
}
TEST(Channel, SendOnClosedChannelPanics) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
size_t i = 5;
EXPECT_EQ(ch->Send(&i), true); // should not block or panic
CloseChannel(ch);
EXPECT_EQ(ch->Send(&i), false); // should panic
delete ch;
}
TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
......@@ -88,7 +98,6 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) {
// Note: we cannot check EXPECT_EQ(out, 0), because C++ doesn't
// define zero values like Go does.
}
delete ch;
}
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <condition_variable>
#include <deque>
#include <mutex>
......@@ -42,7 +43,7 @@ class Buffered : public paddle::framework::Channel<T> {
std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_;
std::deque<T> channel_;
bool closed_;
std::atomic<bool> closed_{false};
Buffered(size_t cap) : cap_(cap), closed_(false) {
PADDLE_ENFORCE_GT(cap, 0);
......@@ -53,10 +54,13 @@ class Buffered : public paddle::framework::Channel<T> {
template <typename T>
bool Buffered<T>::Send(T* item) {
bool ret = false;
if (closed_) {
return ret;
}
std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock,
[this]() { return channel_.size() < cap_ || closed_; });
bool ret = false;
if (!closed_) {
channel_.push_back(std::move(*item));
lock.unlock();
......@@ -82,6 +86,9 @@ bool Buffered<T>::Receive(T* item) {
template <typename T>
void Buffered<T>::Close() {
if (closed_) {
return;
}
std::unique_lock<std::mutex> lock(mu_);
closed_ = true;
NotifyAllParticipants(&lock);
......
......@@ -58,6 +58,10 @@ class UnBuffered : public paddle::framework::Channel<T> {
// be sent from a writer to a reader.
template <typename T>
bool UnBuffered<T>::Send(T* data) {
bool ret = false;
if (closed_) {
return ret;
}
// Prevent other writers from entering
std::unique_lock<std::recursive_mutex> writer_lock(mu_write_);
writer_found_ = true;
......@@ -66,7 +70,6 @@ bool UnBuffered<T>::Send(T* data) {
cv_writer_.wait(cv_lock,
[this]() { return reader_found_ == true || closed_; });
cv_reader_.notify_one();
bool ret = false;
if (!closed_) {
std::unique_lock<std::mutex> channel_lock(mu_ch_);
item = data;
......@@ -114,6 +117,9 @@ bool UnBuffered<T>::Receive(T* data) {
// that take place once the channel is closed.
template <typename T>
void UnBuffered<T>::Close() {
if (closed_) {
return;
}
std::unique_lock<std::mutex> lock(mu_ch_);
item = nullptr;
closed_ = true;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册