// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // This file implements the variant data structure similar to // absl::variant in C++17. #pragma once #include #include #include #include #include "paddle/infrt/support/type_traits.h" namespace infrt { // A Variant similar to absl::variant in C++17. // // Example usage: // // Variant v; // // v = 1; // assert(v.get() == 1); // assert(v.is()); // assert(v.get_if() == nullptr); // // // Print the variant. // visit([](auto& t) { std::cout << t; }, v); // // v.emplace(3); // template class Variant { // Convenient constant to check if a type is a variant. template static constexpr bool IsVariant = std::is_same, Variant>::value; public: using IndexT = int16_t; using Types = std::tuple; template using TypeOf = typename std::tuple_element::type; static constexpr size_t kNTypes = sizeof...(Ts); // Default constructor sets the Variant to the default constructed fisrt type. Variant() { using Type0 = TypeOf<0>; index_ = 0; new (&storage_) Type0(); } template , int> = 0> explicit Variant(T&& t) { fillValue(std::forward(t)); } Variant(const Variant& v) { visit([this](auto& t) { fillValue(t); }, v); } Variant(Variant&& v) { visit([this](auto&& t) { fillValue(std::move(t)); }, v); } ~Variant() { destroy(); } Variant& operator=(Variant&& v) { visit([this](auto& t) { *this = std::move(t); }, v); return *this; } Variant& operator=(const Variant& v) { visit([this](auto& t) { *this = t; }, v); return *this; } template , int> = 0> Variant& operator=(T&& t) { destroy(); fillValue(std::forward(t)); return *this; } template T& emplace(Args&&... args) { AssertHasType(); destroy(); index_ = IndexOf; auto* t = new (&storage_) T(std::forward(args)...); return *t; } template bool is() const { AssertHasType(); return IndexOf == index_; } template const T& get() const { AssertHasType(); return *reinterpret_cast(&storage_); } template T& get() { AssertHasType(); return *reinterpret_cast(&storage_); } template const T* get_if() const { if (is()) return &get(); return nullptr; } template T* get_if() { if (is()) return &get(); return nullptr; } IndexT index() { return index_; } private: template static constexpr size_t IndexOf = TupleIndexOf::value; static constexpr size_t kStorageSize = std::max({sizeof(Ts)...}); static constexpr size_t kAlignment = std::max({alignof(Ts)...}); template static constexpr void AssertHasType() { constexpr bool has_type = TupleHasType::value; static_assert(has_type, "Invalid Type used for Variant"); } void destroy() { visit( [](auto& t) { using T = std::decay_t; t.~T(); }, *this); } template void fillValue(T&& t) { using Type = std::decay_t; AssertHasType(); index_ = IndexOf; new (&storage_) Type(std::forward(t)); } using StorageT = std::aligned_storage_t; StorageT storage_; IndexT index_ = -1; }; struct Monostate {}; namespace internal { template decltype(auto) visitHelper( F&& f, Variant&& v, std::integral_constant::kNTypes>) { assert(false && "Unexpected index_ in Variant"); } // Disable clang-format as it does not format less-than (<) in the template // parameter properly. // // clang-format off template < typename F, typename Variant, int N, std::enable_if_t::kNTypes, int> = 0> decltype(auto) visitHelper(F&& f, Variant&& v, std::integral_constant) { // clang-format on using VariantT = std::decay_t; using T = typename VariantT::template TypeOf; if (auto* t = v.template get_if()) { return f(*t); } else { return visitHelper(std::forward(f), std::forward(v), std::integral_constant()); } } } // namespace internal template decltype(auto) visit(F&& f, Variant&& v) { return internal::visitHelper(std::forward(f), std::forward(v), std::integral_constant()); } } // namespace infrt