// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #if defined _WIN32 || defined __APPLE__ #else #define _LINUX #endif #include #include #include #include #include #include #include #include #include #include #include #include #include "paddle/fluid/framework/expect.h" namespace paddle { namespace framework { // not a virtual class class ArchiveBase { protected: ArchiveBase() {} // Archive is not copyable. But to allow move capture by function objects, // check it at runtime rather than at compile time. ArchiveBase(const ArchiveBase&) { LOG(FATAL) << "Not supported"; } ArchiveBase(ArchiveBase&& other) : buffer_(other.buffer_), cursor_(other.cursor_), finish_(other.finish_), limit_(other.limit_), deleter_(std::move(other.deleter_)) { other.buffer_ = NULL; other.cursor_ = NULL; other.finish_ = NULL; other.limit_ = NULL; other.deleter_ = nullptr; } ~ArchiveBase() { FreeBuffer(); } public: ArchiveBase& operator=(const ArchiveBase&) { LOG(FATAL) << "Not supported"; return *this; } ArchiveBase& operator=(ArchiveBase&& other) { if (this != &other) { FreeBuffer(); buffer_ = other.buffer_; cursor_ = other.cursor_; finish_ = other.finish_; limit_ = other.limit_; deleter_ = std::move(other.deleter_); other.buffer_ = NULL; other.cursor_ = NULL; other.finish_ = NULL; other.limit_ = NULL; other.deleter_ = nullptr; } return *this; } char* Buffer() { return buffer_; } void SetReadBuffer(char* buffer, size_t length, std::function&& deleter) { SetBuffer(buffer, length, length, std::move(deleter)); } void SetWriteBuffer(char* buffer, size_t capacity, std::function&& deleter) { SetBuffer(buffer, 0, capacity, std::move(deleter)); } void SetBuffer(char* buffer, size_t length, size_t capacity, std::function&& deleter) { CHECK(length <= capacity); FreeBuffer(); buffer_ = buffer; cursor_ = buffer_; finish_ = buffer + length; limit_ = buffer + capacity; deleter_ = std::move(deleter); } char* Cursor() { return cursor_; } void SetCursor(char* cursor) { CHECK(cursor >= buffer_ && cursor <= finish_); cursor_ = cursor; } void AdvanceCursor(size_t offset) { CHECK(offset <= size_t(finish_ - cursor_)); cursor_ += offset; } char* Finish() { return finish_; } void SetFinish(char* finish) { CHECK(finish >= cursor_ && finish <= limit_); finish_ = finish; } void AdvanceFinish(size_t offset) { CHECK(offset <= size_t(limit_ - finish_)); finish_ += offset; } char* Limit() { return limit_; } size_t Position() { return cursor_ - buffer_; } size_t Length() { return finish_ - buffer_; } size_t Capacity() { return limit_ - buffer_; } bool Empty() { return finish_ == buffer_; } void Reset() { FreeBuffer(); buffer_ = NULL; cursor_ = NULL; finish_ = NULL; limit_ = NULL; } void Clear() { cursor_ = buffer_; finish_ = buffer_; } char* Release() { char* buf = buffer_; buffer_ = NULL; cursor_ = NULL; finish_ = NULL; deleter_ = nullptr; return buf; } void Resize(size_t newsize) { #ifdef _LINUX if (unlikely(newsize > Capacity())) { #else if (newsize > Capacity()) { #endif Reserve((std::max)(Capacity() * 2, newsize)); } finish_ = buffer_ + newsize; cursor_ = (std::min)(cursor_, finish_); } void Reserve(size_t newcap) { if (newcap > Capacity()) { char* newbuf = NULL; newbuf = new char[newcap]; CHECK(newbuf != nullptr) << "Reserve failed, out of memory"; if (Length() > 0) { memcpy(newbuf, buffer_, Length()); } cursor_ = newbuf + (cursor_ - buffer_); finish_ = newbuf + (finish_ - buffer_); limit_ = newbuf + newcap; FreeBuffer(); buffer_ = newbuf; deleter_ = std::default_delete(); } } void PrepareRead(size_t size) { #ifdef _LINUX if (unlikely(!(size <= size_t(finish_ - cursor_)))) { #else if (!(size <= size_t(finish_ - cursor_))) { #endif CHECK(size <= size_t(finish_ - cursor_)); } } void PrepareWrite(size_t size) { #ifdef _LINUX if (unlikely(size > size_t(limit_ - finish_))) { #else if (size > size_t(limit_ - finish_)) { #endif Reserve((std::max)(Capacity() * 2, Length() + size)); } } void Read(void* data, size_t size) { if (size > 0) { PrepareRead(size); memcpy(data, cursor_, size); AdvanceCursor(size); } } void ReadBack(void* data, size_t size) { if (size > 0) { CHECK(size <= size_t(finish_ - cursor_)); memcpy(data, finish_ - size, size); finish_ -= size; } } void Write(const void* data, size_t size) { if (size > 0) { PrepareWrite(size); memcpy(finish_, data, size); AdvanceFinish(size); } } template void GetRaw(T& x) { // NOLINT PrepareRead(sizeof(T)); memcpy(&x, cursor_, sizeof(T)); AdvanceCursor(sizeof(T)); } template T GetRaw() { T x; GetRaw(x); return x; } template void PutRaw(const T& x) { PrepareWrite(sizeof(T)); memcpy(finish_, &x, sizeof(T)); AdvanceFinish(sizeof(T)); } protected: char* buffer_ = NULL; char* cursor_ = NULL; char* finish_ = NULL; char* limit_ = NULL; std::function deleter_ = nullptr; void FreeBuffer() { if (deleter_) { deleter_(buffer_); } deleter_ = nullptr; } }; // NOLINT template class Archive {}; class BinaryArchiveType {}; typedef Archive BinaryArchive; template <> class Archive : public ArchiveBase { public: #define ARCHIVE_REPEAT(T) \ BinaryArchive& operator>>(T& x) { \ GetRaw(x); \ return *this; \ } \ BinaryArchive& operator<<(const T& x) { \ PutRaw(x); \ return *this; \ } ARCHIVE_REPEAT(int16_t) ARCHIVE_REPEAT(uint16_t) ARCHIVE_REPEAT(int32_t) ARCHIVE_REPEAT(uint32_t) ARCHIVE_REPEAT(int64_t) ARCHIVE_REPEAT(uint64_t) ARCHIVE_REPEAT(float) ARCHIVE_REPEAT(double) ARCHIVE_REPEAT(signed char) ARCHIVE_REPEAT(unsigned char) ARCHIVE_REPEAT(bool) #undef ARCHIVE_REPEAT template T Get() { T x; *this >> x; return x; } template void Printf(const char* fmt, ARGS&&... args) { size_t temp = Limit() - Finish(); int len = snprintf(Finish(), temp, fmt, args...); CHECK(len >= 0); // NOLINT if ((size_t)len >= temp) { PrepareWrite(len + 1); CHECK(snprintf(Finish(), (size_t)len + 1, fmt, args...) == len); } AdvanceFinish(len); } }; template Archive& operator<<(Archive& ar, const T (&p)[N]) { for (size_t i = 0; i < N; i++) { ar << p[i]; } return ar; } template Archive& operator>>(Archive& ar, T (&p)[N]) { for (size_t i = 0; i < N; i++) { ar >> p[i]; } return ar; } template Archive& operator<<(Archive& ar, const std::vector& p) { #ifdef _LINUX ar << (size_t)p.size(); #else ar << (uint64_t)p.size(); #endif for (const auto& x : p) { ar << x; } return ar; } template Archive& operator>>(Archive& ar, std::vector& p) { #ifdef _LINUX p.resize(ar.template Get()); #else p.resize(ar.template Get()); #endif for (auto& x : p) { ar >> x; } return ar; } template Archive& operator<<(Archive& ar, const std::valarray& p) { #ifdef _LINUX ar << (size_t)p.size(); #else ar << (uint64_t)p.size(); #endif for (const auto& x : p) { ar << x; } return ar; } template Archive& operator>>(Archive& ar, std::valarray& p) { #ifdef _LINUX p.resize(ar.template Get()); #else p.resize(ar.template Get()); #endif for (auto& x : p) { ar >> x; } return ar; } inline BinaryArchive& operator<<(BinaryArchive& ar, const std::string& s) { #ifdef _LINUX ar << (size_t)s.length(); #else ar << (uint64_t)s.length(); #endif ar.Write(&s[0], s.length()); return ar; } inline BinaryArchive& operator>>(BinaryArchive& ar, std::string& s) { #ifdef _LINUX size_t len = ar.template Get(); #else size_t len = ar.template Get(); #endif ar.PrepareRead(len); s.assign(ar.Cursor(), len); ar.AdvanceCursor(len); return ar; } template Archive& operator<<(Archive& ar, const std::pair& x) { return ar << x.first << x.second; } template Archive& operator>>(Archive& ar, std::pair& x) { // NOLINT return ar >> x.first >> x.second; } #ifdef _LINUX template Archive& SerializeTuple(Archive& ar, // NOLINT const std::tuple& x, // NOLINT std::integral_constant n) { // NOLINT return ar; } #else template Archive& SerializeTuple(Archive& ar, // NOLINT const std::tuple& x, // NOLINT std::integral_constant n) { // NOLINT return ar; } #endif #ifdef _LINUX template Archive& serialize_tuple(Archive& ar, // NOLINT const std::tuple& x, // NOLINT std::integral_constant n) { // NOLINT return SerializeTuple(ar, x, std::integral_constant()) << std::get(x); } #else template Archive& serialize_tuple(Archive& ar, // NOLINT const std::tuple& x, // NOLINT std::integral_constant n) { // NOLINT return SerializeTuple(ar, x, std::integral_constant()) << std::get(x); } #endif #ifdef _LINUX template Archive& operator<<(Archive& ar, const std::tuple& x) { const size_t size = std::tuple_size>::value; return SerializeTuple(ar, x, std::integral_constant()); } #else template Archive& operator<<(Archive& ar, const std::tuple& x) { const uint64_t size = std::tuple_size>::value; return SerializeTuple(ar, x, std::integral_constant()); } #endif #ifdef _LINUX template Archive& DeserializeTuple(Archive& ar, std::tuple& x, // NOLINT std::integral_constant n) { return ar; } #else template Archive& DeserializeTuple(Archive& ar, std::tuple& x, // NOLINT std::integral_constant n) { return ar; } #endif #ifdef _LINUX template Archive& DeserializeTuple(Archive& ar, std::tuple& x, // NOLINT std::integral_constant n) { return DeserializeTuple(ar, x, std::integral_constant()) >> std::get(x); } #else template Archive& DeserializeTuple(Archive& ar, std::tuple& x, // NOLINT std::integral_constant n) { return DeserializeTuple(ar, x, std::integral_constant()) >> std::get(x); } #endif #ifdef _LINUX template Archive& operator>>(Archive& ar, std::tuple& x) { const size_t size = std::tuple_size>::value; return DeserializeTuple(ar, x, std::integral_constant()); } #else template Archive& operator>>(Archive& ar, std::tuple& x) { const uint64_t size = std::tuple_size>::value; return DeserializeTuple(ar, x, std::integral_constant()); } #endif #ifdef _LINUX #define ARCHIVE_REPEAT(MAP_TYPE, RESERVE_STATEMENT) \ template \ Archive& operator<<(Archive& ar, \ const MAP_TYPE& p) { \ ar << (size_t)p.size(); \ for (auto it = p.begin(); it != p.end(); ++it) { \ ar << *it; \ } \ return ar; \ } \ template \ Archive& operator>>(Archive& ar, MAP_TYPE& p) { \ size_t size = ar.template get(); \ p.clear(); \ RESERVE_STATEMENT; \ for (size_t i = 0; i < size; i++) { \ p.insert(ar.template get>()); \ } \ return ar; \ } #else #define ARCHIVE_REPEAT(MAP_TYPE, RESERVE_STATEMENT) \ template \ Archive& operator<<(Archive& ar, \ const MAP_TYPE& p) { \ ar << (uint64_t)p.size(); \ for (auto it = p.begin(); it != p.end(); ++it) { \ ar << *it; \ } \ return ar; \ } \ template \ Archive& operator>>(Archive& ar, MAP_TYPE& p) { \ size_t size = ar.template get(); \ p.clear(); \ RESERVE_STATEMENT; \ for (size_t i = 0; i < size; i++) { \ p.insert(ar.template get>()); \ } \ return ar; \ } #endif ARCHIVE_REPEAT(std::map, ) ARCHIVE_REPEAT(std::multimap, ) ARCHIVE_REPEAT(std::unordered_map, p.reserve(size)) ARCHIVE_REPEAT(std::unordered_multimap, p.reserve(size)) #undef ARCHIVE_REPEAT #ifdef _LINUX #define ARCHIVE_REPEAT(SET_TYPE, RESERVE_STATEMENT) \ template \ Archive& operator<<(Archive& ar, const SET_TYPE& p) { \ ar << (size_t)p.size(); \ for (auto it = p.begin(); it != p.end(); ++it) { \ ar << *it; \ } \ return ar; \ } \ template \ Archive& operator>>(Archive& ar, SET_TYPE& p) { \ size_t size = ar.template get(); \ p.clear(); \ RESERVE_STATEMENT; \ for (size_t i = 0; i < size; i++) { \ p.insert(ar.template get()); \ } \ return ar; \ } #else #define ARCHIVE_REPEAT(SET_TYPE, RESERVE_STATEMENT) \ template \ Archive& operator<<(Archive& ar, const SET_TYPE& p) { \ ar << (uint64_t)p.size(); \ for (auto it = p.begin(); it != p.end(); ++it) { \ ar << *it; \ } \ return ar; \ } \ template \ Archive& operator>>(Archive& ar, SET_TYPE& p) { \ size_t size = ar.template get(); \ p.clear(); \ RESERVE_STATEMENT; \ for (size_t i = 0; i < size; i++) { \ p.insert(ar.template get()); \ } \ return ar; \ } #endif ARCHIVE_REPEAT(std::set, ) ARCHIVE_REPEAT(std::multiset, ) ARCHIVE_REPEAT(std::unordered_set, p.reserve(size)) ARCHIVE_REPEAT(std::unordered_multiset, p.reserve(size)) #undef ARCHIVE_REPEAT } // namespace framework } // namespace paddle