// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #if defined _WIN32 || defined __APPLE__ #else #define _LINUX #endif #include #include #include // NOLINT #include #include #include #include // NOLINT #include #include #include "paddle/fluid/framework/expect.h" namespace paddle { namespace framework { template class ChannelObject { public: ChannelObject() {} // capacity can be zero explicit ChannelObject(size_t capacity) { capacity_ = (std::min)(MaxCapacity(), capacity); } const std::deque& GetData() const { return data_; } void Clear() { std::unique_lock lock(mutex_); data_.clear(); data_.shrink_to_fit(); } size_t Capacity() { return capacity_; // atomic } void SetCapacity(size_t x) { // capacity can be zero std::lock_guard lock(mutex_); capacity_ = std::min(MaxCapacity(), x); Notify(); } size_t BlockSize() { return block_size_; // atomic } void SetBlockSize(size_t x) { CHECK(x >= 1) << "block size must be >= 1"; std::lock_guard lock(mutex_); block_size_ = x; } template void InheritFrom(const std::shared_ptr>& other) { std::lock_guard lock(mutex_); capacity_ = other->Capacity(); block_size_ = other->BlockSize(); } bool Closed() { return closed_; // atomic } // open channel, then data can be write() to channel void Open() { std::lock_guard lock(mutex_); closed_ = false; Notify(); } // close channel, then no more data can be write() to channel void Close() { std::lock_guard lock(mutex_); closed_ = true; Notify(); } size_t Size() { std::lock_guard lock(mutex_); return data_.size(); } bool Empty() { std::lock_guard lock(mutex_); return EmptyUnlocked(); } // blocking operation bool Get(T& val) { return Read(1, &val) != 0; } // NOLINT // blocking operation // returns 0 if the channel is closed and empty size_t Read(size_t n, T* p) { if (n == 0) { return 0; } std::unique_lock lock(mutex_); size_t finished = Read(n, p, lock); Notify(); return finished; } // blocking operation bool Put(T&& val) { return WriteMove(1, &val) != 0; } // blocking operation bool Put(const T& val) { return Write(1, &val) != 0; } // blocking operation // returns value less than n if the channel is closed size_t Write(size_t n, const T* p) { if (n == 0) { return 0; } std::unique_lock lock(mutex_); size_t finished = Write(n, p, lock); Notify(); return finished; } // WriteMove() will clear original contents of input array size_t WriteMove(size_t n, T* p) { if (n == 0) { return 0; } std::unique_lock lock(mutex_); size_t finished = WriteMove(n, p, lock); Notify(); return finished; } // read data of block size from channel to vector size_t Read(std::vector& p) { // NOLINT p.resize(block_size_); size_t finished = Read(p.size(), &p[0]); p.resize(finished); return finished; } size_t ReadAll(std::vector& p) { // NOLINT p.clear(); size_t finished = 0; size_t n = 0; do { // _block_size may change anytime n = block_size_; p.resize(finished + n); n = Read(n, &p[finished]); finished += n; } while (n != 0); p.resize(finished); return finished; } // write data from vector to channel size_t Write(const std::vector& p) { return Write(p.size(), &p[0]); } // write data from vector to channel size_t Write(std::vector&& p) { return WriteMove(p.size(), &p[0]); } private: size_t capacity_ = MaxCapacity(); size_t block_size_ = 1024; bool closed_ = false; std::mutex mutex_; // use deque to store data std::deque data_; size_t reading_count_ = 0; int empty_waiters_ = 0; int full_waiters_ = 0; std::condition_variable empty_cond_; std::condition_variable full_cond_; static constexpr size_t MaxCapacity() { return (std::numeric_limits::max)() / 2; } void Notify() { if (empty_waiters_ != 0 && (!EmptyUnlocked() || closed_)) { empty_cond_.notify_one(); } if (full_waiters_ != 0 && (!FullUnlocked() || closed_)) { full_cond_.notify_one(); } } bool EmptyUnlocked() { return data_.empty(); } bool FullUnlocked() { return data_.size() >= capacity_ + reading_count_; } bool WaitForRead(std::unique_lock& lock) { // NOLINT #ifdef _LINUX while (unlikely(EmptyUnlocked() && !closed_)) { #else while (EmptyUnlocked() && !closed_) { #endif if (full_waiters_ != 0) { full_cond_.notify_one(); } empty_waiters_++; empty_cond_.wait(lock); empty_waiters_--; } return !EmptyUnlocked(); } bool WaitForWrite(std::unique_lock& lock) { // NOLINT #ifdef _LINUX while (unlikely(FullUnlocked() && !closed_)) { #else while (FullUnlocked() && !closed_) { #endif if (empty_waiters_ != 0) { empty_cond_.notify_one(); } full_waiters_++; full_cond_.wait(lock); full_waiters_--; } return !closed_; } size_t Read(size_t n, T* p, std::unique_lock& lock) { // NOLINT size_t finished = 0; CHECK(n <= MaxCapacity() - reading_count_); reading_count_ += n; while (finished < n && WaitForRead(lock)) { size_t m = std::min(n - finished, data_.size()); for (size_t i = 0; i < m; i++) { p[finished++] = std::move(data_.front()); data_.pop_front(); } reading_count_ -= m; } reading_count_ -= n - finished; return finished; } size_t Write(size_t n, const T* p, // NOLINT std::unique_lock& lock) { // NOLINT size_t finished = 0; while (finished < n && WaitForWrite(lock)) { size_t m = std::min(n - finished, capacity_ + reading_count_ - data_.size()); for (size_t i = 0; i < m; i++) { data_.push_back(p[finished++]); } } return finished; } size_t WriteMove(size_t n, T* p, // NOLINT std::unique_lock& lock) { // NOLINT size_t finished = 0; while (finished < n && WaitForWrite(lock)) { size_t m = (std::min)(n - finished, capacity_ + reading_count_ - data_.size()); for (size_t i = 0; i < m; i++) { data_.push_back(std::move(p[finished++])); } } return finished; } }; // NOLINT template using Channel = std::shared_ptr>; template Channel MakeChannel(size_t capacity = (std::numeric_limits::max)()) { return std::make_shared>(capacity); } template Channel MakeChannel(const Channel& other) { CHECK(other != nullptr) << "channel can not be NULL"; Channel chan = std::make_shared>(); chan->InheritFrom(other); return chan; } // NOTE: ChannelReader is a wrapper for quick read channel with a buffer. It // will read a block data from channel, but user can get data one by one. So it // is important to notice that user must call operator>> until false, or call // get_buffer_remain until false to make sure the buffered data all readed. template class ChannelReader { public: explicit ChannelReader(ChannelObject* channel = nullptr) { Reset(channel); } ~ChannelReader() { CHECK(cursor_ == 0) << "Forgot to read buffer data"; } ChannelObject* channel() { return channel_; } void Reset(ChannelObject* channel) { CHECK(channel != nullptr) << "Channel can not be nullptr"; channel_ = channel; cursor_ = 0; failed_ = !channel; } // whether there were read failed operator bool() { return !failed_; } ChannelReader& operator>>(T& val) { if (failed_) { return *this; } if (cursor_ >= buffer_.size()) { cursor_ = 0; if (channel_->read(buffer_) == 0) { failed_ = true; return *this; } } val = std::move(buffer_[cursor_++]); return *this; } bool GetBufferRemain(T& val) { // NOLINT if (cursor_ >= buffer_.size()) { cursor_ = 0; return false; } val = std::move(buffer_[cursor_++]); return true; } private: ChannelObject* channel_ = nullptr; std::vector buffer_; size_t cursor_ = 0; bool failed_ = true; }; // NOLINT template class ChannelWriter { public: explicit ChannelWriter(ChannelObject* channel = nullptr) { Reset(channel); } ~ChannelWriter() { CHECK(buffer_.empty()) << "Forgot to flush"; } ChannelObject* channel() { return channel_; } void Reset(ChannelObject* channel) { CHECK(buffer_.empty()) << "Forgot to flush"; // CHECK(channel != nullptr) << "Channel can not be nullptr"; channel_ = channel; buffer_.clear(); failed_ = !channel; } // whether there were write failed operator bool() { return !failed_; } ChannelWriter& operator<<(T&& val) { if (failed_) { return *this; } buffer_.push_back(std::move(val)); if (buffer_.size() >= channel_->BlockSize()) { Flush(); } return *this; } ChannelWriter& operator<<(const T& val) { if (failed_) { return *this; } buffer_.push_back(val); if (buffer_.size() >= channel_->BlockSize()) { Flush(); } return *this; } void Flush() { if (failed_ || buffer_.empty()) { buffer_.clear(); return; } failed_ |= channel_->WriteMove(buffer_.size(), &buffer_[0]) != buffer_.size(); buffer_.clear(); } private: ChannelObject* channel_ = nullptr; std::vector buffer_; bool failed_ = true; }; // NOLINT // only used for range-for loop // for (auto& x : chan) {...} template struct ChannelIterator { std::shared_ptr> reader_; T data_; void operator++() { CHECK(reader_ != nullptr) << "reader can not be NULL"; if (!(*reader_ >> data_)) { reader_ = nullptr; } } T& operator*() { return data_; } friend bool operator==(const ChannelIterator& a, const ChannelIterator& b) { return a.reader_ == b.reader_; } friend bool operator!=(const ChannelIterator& a, const ChannelIterator& b) { return a.reader_ != b.reader_; } }; // NOLINT template ChannelIterator begin(ChannelObject* chan) { ChannelIterator it{std::make_shared>(chan), T()}; ++it; return it; } template ChannelIterator end(ChannelObject* chan) { return {nullptr, T()}; } } // namespace framework } // namespace paddle