diff --git a/paddle/operators/detail/channel.h b/paddle/operators/detail/channel.h new file mode 100644 index 0000000000000000000000000000000000000000..cbfdf800401e8e92bead6f23ff8f1417e35989aa --- /dev/null +++ b/paddle/operators/detail/channel.h @@ -0,0 +1,89 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +namespace paddle { +namespace operators { +namespace detail { + +template +class Channel { + public: + explicit Channel(std::size_t capacity) : capacity_(capacity) {} + + void Send(T* channel_element) { + std::unique_lock lock(mu_); + + if (IsBounded()) { + full_cond_var_.wait(lock, [this]() { + bool capacity_valid = capacity_ > 0 ? !IsCapacityFull() : true; + return capacity_valid; + }); + } + channel_.push_back(std::move(*channel_element)); + + lock.unlock(); + empty_cond_var_.notify_all(); + } + + T* Receive() { + std::unique_lock lock(mu_); + empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); }); + + T* channel_element = std::move(channel_.front()); + channel_.pop_front(); + + NotifyAllSenders(&lock); + return channel_element; + } + + size_t Size() { + std::unique_lock lock(mu_); + return channel_.size(); + } + + void Clear() { + std::unique_lock lock(mu_); + channel_.clear(); + + NotifyAllSenders(&lock); + } + + private: + std::size_t capacity_; + std::mutex mu_; + std::condition_variable empty_cond_var_; + std::condition_variable full_cond_var_; + std::deque channel_; + + private: + void NotifyAllSenders(std::unique_lock* lock) { + if (IsBounded()) { + lock->unlock(); + full_cond_var_.notify_all(); + } + } + + bool IsBounded() const { return capacity_ > 0; } + + bool IsCapacityFull() const { return channel_.size() >= capacity_; } +}; + +} // namespace detail +} // namespace operator +} // namespace paddle