From ab24b9c05d20a6d54a913d11bf3f5c5c9511408e Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 26 Apr 2022 09:33:41 +0800 Subject: [PATCH] [Cherry-pick] Optimize dygraph performance part2 (#42224) * Add paddle::variant and replace paddle::any (#42139) * add variant and replace any * split attribute * Optimize dygraph GetExpectedKernelType perf (#42154) * opt dygraph scheduling * revert part impl * fix variant compile error (#42203) * replace any by variant in infermeta (#42181) --- paddle/fluid/framework/custom_operator.cc | 1 + .../framework/new_executor/interpretercore.cc | 1 + .../new_executor/interpretercore_util.cc | 1 + paddle/fluid/framework/operator.cc | 48 +- paddle/fluid/framework/operator.h | 17 +- paddle/fluid/imperative/execution_context.h | 18 +- paddle/fluid/imperative/prepared_operator.h | 1 + paddle/fluid/operators/transpose_op.cc | 2 +- paddle/phi/core/attribute.h | 50 + paddle/phi/core/infermeta_utils.cc | 34 +- paddle/phi/core/infermeta_utils.h | 60 +- paddle/phi/core/kernel_context.cc | 32 +- paddle/phi/core/kernel_context.h | 21 +- paddle/phi/core/kernel_registry.h | 10 + paddle/phi/core/kernel_utils.h | 39 +- paddle/phi/core/type_defs.h | 2 + paddle/phi/infermeta/unary.cc | 9 +- paddle/phi/infermeta/unary.h | 5 - paddle/phi/kernels/cpu/where_grad_kernel.cc | 2 + paddle/phi/kernels/cpu/where_kernel.cc | 2 + paddle/phi/kernels/funcs/activation_functor.h | 2 +- paddle/phi/kernels/gpu/where_grad_kernel.cu | 3 + paddle/phi/kernels/gpu/where_kernel.cu | 2 + paddle/phi/kernels/where_grad_kernel.h | 3 - paddle/phi/kernels/where_kernel.h | 3 - paddle/phi/tests/core/test_custom_kernel.cc | 5 - paddle/phi/tests/core/test_meta_fn_utils.cc | 26 - paddle/utils/variant.h | 2830 +++++++++++++++++ 28 files changed, 3109 insertions(+), 120 deletions(-) create mode 100644 paddle/phi/core/attribute.h create mode 100644 paddle/utils/variant.h diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 3f28b2e8c7..65c41e19ac 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -39,6 +39,7 @@ limitations under the License. */ #include "paddle/phi/api/all.h" #include "paddle/phi/api/lib/utils/tensor_utils.h" #include "paddle/phi/core/compat/convert_utils.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/utils/any.h" namespace paddle { diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 67c79c9bf9..edc066ac55 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -22,6 +22,7 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/os_info.h" #include "paddle/fluid/platform/profiler/event_tracing.h" +#include "paddle/phi/core/kernel_context.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index ed813c78bc..81b1c159ef 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h" #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h" +#include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_factory.h" #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 1dd47873c0..279422eb35 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -35,6 +35,7 @@ limitations under the License. */ #include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/ops/compat/signatures.h" @@ -941,7 +942,7 @@ class RuntimeInferShapeContext : public InferShapeContext { return ((op_with_kernel.kernel_type()) && (op_with_kernel.kernel_type()->data_layout_ == framework::DataLayout::kMKLDNN)); - } catch (std::bad_cast exp) { + } catch (const std::bad_cast& exp) { return false; } } @@ -1966,6 +1967,36 @@ Scope* OperatorWithKernel::PrepareData( } void OperatorWithKernel::ParseInputDataType( + const Variable* var, const std::string& name, + proto::VarType::Type* data_type) const { + if (var != nullptr) { + const Tensor* t = nullptr; + if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &(var->Get().value()); + } else if (var->IsType()) { + auto t_arr = &var->Get(); + for (size_t j = 0; j < t_arr->size(); j++) { + if (t_arr->at(j).IsInitialized()) { + t = &(t_arr->at(j)); + } + } + } + if (t != nullptr) { + PADDLE_ENFORCE_EQ( + t->IsInitialized(), true, + platform::errors::InvalidArgument("The %s Op's Input Variable `%s` " + "contains uninitialized Tensor.", + Type(), name)); + *data_type = paddle::framework::TransToProtoVarType(t->dtype()); + } + } +} + +void OperatorWithKernel::ParseMultiInputDataType( const std::vector& vars, const std::string& name, proto::VarType::Type* data_type) const { proto::VarType::Type default_data_type = @@ -2016,9 +2047,12 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( proto::VarType::Type dafault_data_type = static_cast(-1); proto::VarType::Type data_type = dafault_data_type; - for (auto& input : ctx.InNameList()) { - const std::vector vars = ctx.MultiInputVar(input); - ParseInputDataType(vars, input, &data_type); + for (auto* name : ctx.InNameList()) { + if (ctx.InputSize(*name) == 1UL) { + ParseInputDataType(ctx.InputVar(*name), *name, &data_type); + } else { + ParseMultiInputDataType(ctx.MultiInputVar(*name), *name, &data_type); + } } PADDLE_ENFORCE_NE( data_type, dafault_data_type, @@ -2032,7 +2066,11 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType( proto::VarType::Type dafault_data_type = static_cast(-1); proto::VarType::Type data_type = dafault_data_type; - ParseInputDataType(ctx.MultiInputVar(name), name, &data_type); + if (ctx.InputSize(name) == 1UL) { + ParseInputDataType(ctx.InputVar(name), name, &data_type); + } else { + ParseMultiInputDataType(ctx.MultiInputVar(name), name, &data_type); + } PADDLE_ENFORCE_NE( data_type, dafault_data_type, platform::errors::InvalidArgument( diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index f0887eb919..dd21be12f4 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -43,7 +43,6 @@ limitations under the License. */ #include "paddle/fluid/framework/convert_utils.h" #include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/op_utils.h" -#include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_factory.h" namespace paddle { @@ -55,6 +54,10 @@ class Variable; } // namespace framework } // namespace paddle +namespace phi { +class KernelContext; +} + DECLARE_int32(inner_op_parallelism); namespace paddle { @@ -330,12 +333,12 @@ class ExecutionContext { return it->second; } - virtual std::vector InNameList() const { - std::vector vec_temp; + virtual paddle::SmallVector InNameList() const { + paddle::SmallVector vec_temp; vec_temp.reserve(ctx_.inputs.size()); for (auto& input : ctx_.inputs) { - vec_temp.push_back(input.first); + vec_temp.push_back(&input.first); } return vec_temp; @@ -677,9 +680,11 @@ class OperatorWithKernel : public OperatorBase { // By default all input data must be same. proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const; // used for IndicateDataType - void ParseInputDataType(const std::vector& vars, - const std::string& name, + void ParseInputDataType(const Variable* vars, const std::string& name, proto::VarType::Type* data_type) const; + void ParseMultiInputDataType(const std::vector& vars, + const std::string& name, + proto::VarType::Type* data_type) const; // used for IndicateOrPromoteVarDataTypes Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx, const std::string& name) const; diff --git a/paddle/fluid/imperative/execution_context.h b/paddle/fluid/imperative/execution_context.h index fbc47f81fd..330a5a0cfa 100644 --- a/paddle/fluid/imperative/execution_context.h +++ b/paddle/fluid/imperative/execution_context.h @@ -117,12 +117,12 @@ class DygraphExecutionContext : public framework::ExecutionContext { return it->second; } - std::vector InNameList() const override { - std::vector vec_temp; + paddle::SmallVector InNameList() const override { + paddle::SmallVector vec_temp; vec_temp.reserve(var_map_in_.size()); for (auto& v : var_map_in_) { - vec_temp.push_back(v.first); + vec_temp.push_back(&v.first); } return vec_temp; @@ -144,11 +144,19 @@ class DygraphExecutionContext : public framework::ExecutionContext { } size_t InputSize(const std::string& name) const override { - return InputNames(name).size(); + auto it = var_map_in_.find(name); + PADDLE_ENFORCE_NE( + it, var_map_in_.end(), + platform::errors::NotFound("Can not find [%s] in Input", name)); + return it->second.size(); } size_t OutputSize(const std::string& name) const override { - return OutputNames(name).size(); + auto it = var_map_out_.find(name); + PADDLE_ENFORCE_NE( + it, var_map_out_.end(), + platform::errors::NotFound("Can not find [%s] in Output", name)); + return it->second.size(); } const Variable* InputVar(const std::string& name) const override { diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 754b553bd1..0e75775e91 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -31,6 +31,7 @@ #include "paddle/fluid/framework/convert_utils.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/selected_rows.h" DECLARE_bool(use_mkldnn); diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 1a297e7238..a45d32b34b 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -90,7 +90,7 @@ class TransposeOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = ctx.Attr("data_format"); + auto &data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/phi/core/attribute.h b/paddle/phi/core/attribute.h new file mode 100644 index 0000000000..d1b2920335 --- /dev/null +++ b/paddle/phi/core/attribute.h @@ -0,0 +1,50 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/utils/variant.h" + +namespace phi { + +class Place; + +// NOTE: Add needed type in the future +using Attribute = paddle::variant, + std::vector, + std::vector, + std::vector, + std::vector, + std::vector, + Scalar, + std::vector, + IntArray, + DataType, + DataLayout, + Place>; + +} // namespace phi diff --git a/paddle/phi/core/infermeta_utils.cc b/paddle/phi/core/infermeta_utils.cc index 70f26102cb..8bdad9d6d2 100644 --- a/paddle/phi/core/infermeta_utils.cc +++ b/paddle/phi/core/infermeta_utils.cc @@ -30,7 +30,7 @@ void InferMetaContext::EmplaceBackOutput(MetaTensor output) { outputs_.emplace_back(std::move(output)); output_range_.emplace_back(std::pair(index, index + 1)); } -void InferMetaContext::EmplaceBackAttr(paddle::any attr) { +void InferMetaContext::EmplaceBackAttr(Attribute attr) { attrs_.emplace_back(std::move(attr)); } @@ -120,6 +120,38 @@ std::vector InferMetaContext::MutableOutputBetween(size_t start, return result; } +template +const AttrType& InferMetaContext::AttrAt(size_t idx) const { + try { + return paddle::get(attrs_.at(idx)); + } catch (paddle::bad_variant_access const& e) { + PADDLE_THROW(phi::errors::InvalidArgument( + "Attribute cast error in InferMeta Context, the expected attribute " + "type is `%s`.", + std::type_index(typeid(AttrType)).name())); + } +} + +template const bool& InferMetaContext::AttrAt(size_t idx) const; +template const int& InferMetaContext::AttrAt(size_t idx) const; +template const int64_t& InferMetaContext::AttrAt(size_t idx) const; +template const float& InferMetaContext::AttrAt(size_t idx) const; +template const double& InferMetaContext::AttrAt(size_t idx) const; +template const std::string& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt( + size_t idx) const; +template const Scalar& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt(size_t idx) const; +template const IntArray& InferMetaContext::AttrAt(size_t idx) const; +template const DataType& InferMetaContext::AttrAt(size_t idx) const; +template const DataLayout& InferMetaContext::AttrAt(size_t idx) const; +template const Place& InferMetaContext::AttrAt(size_t idx) const; + MetaFnFactory& MetaFnFactory::Instance() { static MetaFnFactory g_meta_fn_map; return g_meta_fn_map; diff --git a/paddle/phi/core/infermeta_utils.h b/paddle/phi/core/infermeta_utils.h index 699c38ebd4..8c726bffa2 100644 --- a/paddle/phi/core/infermeta_utils.h +++ b/paddle/phi/core/infermeta_utils.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/attribute.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/macros.h" #include "paddle/phi/core/meta_tensor.h" @@ -41,7 +42,7 @@ class InferMetaContext { void EmplaceBackInput(MetaTensor input); void EmplaceBackOutput(MetaTensor output); - void EmplaceBackAttr(paddle::any attr); + void EmplaceBackAttr(Attribute attr); void EmplaceBackInputs( paddle::SmallVector inputs); @@ -61,17 +62,7 @@ class InferMetaContext { size_t end); template - AttrType AttrAt(size_t idx) { - try { - return paddle::any_cast(attrs_.at(idx)); - } catch (paddle::bad_any_cast& e) { - PADDLE_THROW(phi::errors::InvalidArgument( - "Attribute cast error in InferMeta Context, the expected attribute " - "type is `%s`, but actual attribute type is `%s`.", - std::type_index(typeid(AttrType)).name(), - std::type_index(attrs_.at(idx).type()).name())); - } - } + const AttrType& AttrAt(size_t idx) const; const std::pair& InputRangeAt(size_t idx) const; const std::pair& OutputRangeAt(size_t idx) const; @@ -81,7 +72,7 @@ class InferMetaContext { protected: MetaConfig config_; - paddle::SmallVector attrs_; + paddle::SmallVector attrs_; paddle::SmallVector, phi::kInputSmallVectorSize> input_range_; @@ -111,6 +102,21 @@ class InferMetaContext { } \ } +#define PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(attr_type) \ + template \ + struct InferMetaFnCallHelper { \ + template \ + static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { \ + static_assert(out_idx == 0, \ + "InferMeta's Attributes should appear before Outputs."); \ + const attr_type& arg = ctx->AttrAt(attr_idx); \ + InferMetaFnCallHelper< \ + Tail...>::template Call(ctx, \ + pargs..., \ + arg); \ + } \ + } + template struct InferMetaTypeTag {}; @@ -201,27 +207,27 @@ struct InferMetaFnImpl { } }; - // TODO(chenweihang): support other attr type later PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::string&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE( - const std::vector&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE( - const std::vector&); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const IntArray&); - - // TODO(chenweihang): support vector input later + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); template struct InferMetaFnCallHelper { diff --git a/paddle/phi/core/kernel_context.cc b/paddle/phi/core/kernel_context.cc index cf862cbde1..9935a5bf5c 100644 --- a/paddle/phi/core/kernel_context.cc +++ b/paddle/phi/core/kernel_context.cc @@ -73,7 +73,7 @@ void KernelContext::EmplaceBackOutputsWithoutSetRange( std::make_move_iterator(outputs.end())); } -void KernelContext::EmplaceBackAttr(paddle::any attr) { +void KernelContext::EmplaceBackAttr(Attribute attr) { attrs_.emplace_back(std::move(attr)); } @@ -113,4 +113,34 @@ const std::pair& KernelContext::OutputRangeAt(size_t idx) const { return output_range_.at(idx); } +template +const AttrType& KernelContext::AttrAt(size_t idx) const { + try { + return paddle::get(attrs_.at(idx)); + } catch (paddle::bad_variant_access const& ex) { + PADDLE_THROW(phi::errors::InvalidArgument( + "Attribute cast error in Op Kernel Context.")); + } +} + +template const bool& KernelContext::AttrAt(size_t idx) const; +template const int& KernelContext::AttrAt(size_t idx) const; +template const int64_t& KernelContext::AttrAt(size_t idx) const; +template const float& KernelContext::AttrAt(size_t idx) const; +template const double& KernelContext::AttrAt(size_t idx) const; +template const std::string& KernelContext::AttrAt(size_t idx) const; +template const std::vector& KernelContext::AttrAt(size_t idx) const; +template const std::vector& KernelContext::AttrAt(size_t idx) const; +template const std::vector& KernelContext::AttrAt(size_t idx) const; +template const std::vector& KernelContext::AttrAt(size_t idx) const; +template const std::vector& KernelContext::AttrAt(size_t idx) const; +template const std::vector& KernelContext::AttrAt( + size_t idx) const; +template const Scalar& KernelContext::AttrAt(size_t idx) const; +template const std::vector& KernelContext::AttrAt(size_t idx) const; +template const IntArray& KernelContext::AttrAt(size_t idx) const; +template const DataType& KernelContext::AttrAt(size_t idx) const; +template const DataLayout& KernelContext::AttrAt(size_t idx) const; +template const Place& KernelContext::AttrAt(size_t idx) const; + } // namespace phi diff --git a/paddle/phi/core/kernel_context.h b/paddle/phi/core/kernel_context.h index ab4e044e62..a06efb573a 100644 --- a/paddle/phi/core/kernel_context.h +++ b/paddle/phi/core/kernel_context.h @@ -17,11 +17,12 @@ #include #include +#include "paddle/phi/core/attribute.h" #include "paddle/phi/core/device_context.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_utils.h" -#include "paddle/utils/any.h" +#include "paddle/phi/core/type_defs.h" #include "paddle/utils/optional.h" #include "paddle/utils/small_vector.h" @@ -64,7 +65,7 @@ class KernelContext { void EmplaceBackOutputsWithoutSetRange( paddle::SmallVector outputs); - void EmplaceBackAttr(paddle::any attr); + void EmplaceBackAttr(Attribute attr); const std::pair& InputRangeAt(size_t idx) const; @@ -128,14 +129,7 @@ class KernelContext { } template - AttrType AttrAt(size_t idx) const { - try { - return paddle::any_cast(attrs_.at(idx)); - } catch (paddle::bad_any_cast&) { - PADDLE_THROW(phi::errors::InvalidArgument( - "Attribute cast error in Op Kernel Context.")); - } - } + const AttrType& AttrAt(size_t idx) const; size_t InputsSize() const { return inputs_.size(); } size_t OutputsSize() const { return outputs_.size(); } @@ -146,10 +140,11 @@ class KernelContext { paddle::SmallVector inputs_; paddle::SmallVector outputs_; - paddle::SmallVector attrs_; + paddle::SmallVector attrs_; - paddle::SmallVector> input_range_; - paddle::SmallVector> output_range_; + paddle::SmallVector, kInputSmallVectorSize> input_range_; + paddle::SmallVector, kOutputSmallVectorSize> + output_range_; }; } // namespace phi diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h index b18fd9e05f..356ab58f40 100644 --- a/paddle/phi/core/kernel_registry.h +++ b/paddle/phi/core/kernel_registry.h @@ -105,6 +105,11 @@ struct KernelArgsParseFunctor { default_tensor_layout, default_key.dtype(), arg_type); + } else if (arg_type == std::type_index(typeid(const StringTensor&))) { + args_def->AppendInput(default_key.backend(), + default_tensor_layout, + default_key.dtype(), + arg_type); } else if (arg_type == std::type_index(typeid(const SparseCooTensor&))) { args_def->AppendInput(default_key.backend(), default_tensor_layout, @@ -153,6 +158,11 @@ struct KernelArgsParseFunctor { default_tensor_layout, default_key.dtype(), arg_type); + } else if (arg_type == std::type_index(typeid(StringTensor*))) { + args_def->AppendOutput(default_key.backend(), + default_tensor_layout, + default_key.dtype(), + arg_type); } else { // Attribute deal with // TODO(chenweihang): now here allow any types of attribute, maybe diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h index 55574ea03a..ddc58f512b 100644 --- a/paddle/phi/core/kernel_utils.h +++ b/paddle/phi/core/kernel_utils.h @@ -168,6 +168,24 @@ namespace phi { } \ } +#define PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(attr_type) \ + template \ + struct KernelCallHelper { \ + template \ + static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ + static_assert(out_idx == 0, \ + "Kernel's Attributes should appear before Outputs."); \ + const attr_type& arg = ctx->AttrAt(attr_idx); \ + KernelCallHelper:: \ + template Compute( \ + ctx, pargs..., arg); \ + } \ + } + #define PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(tensor_type) \ template \ struct KernelCallHelper { \ @@ -270,19 +288,20 @@ struct KernelImpl { PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(phi::dtype::float16); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataLayout); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(Place); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const IntArray&); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::string&); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); + PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string); + PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar); + PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray); + PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); + PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); + PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); + PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); + PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); + PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); /* Output Helpers */ diff --git a/paddle/phi/core/type_defs.h b/paddle/phi/core/type_defs.h index a1e7836088..0af1c0af23 100644 --- a/paddle/phi/core/type_defs.h +++ b/paddle/phi/core/type_defs.h @@ -15,6 +15,8 @@ #pragma once #include +#include +#include namespace phi { diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 24d13bcc4b..7514f19ef4 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -228,13 +228,6 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) { out->set_dtype(x.dtype()); } -void CopyToInferMeta(const MetaTensor& x, - Backend backend, - bool blocking, - MetaTensor* out) { - UnchangedInferMeta(x, out); -} - void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) { out->set_dims(x.dims()); out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype); @@ -3002,5 +2995,5 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) { } // namespace phi -PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); +PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta); PD_REGISTER_INFER_META_FN(split, phi::SplitInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index ac5040388b..70b868eeb5 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -58,11 +58,6 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out); -void CopyToInferMeta(const MetaTensor& x, - Backend backend, - bool blocking, - MetaTensor* out); - void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out); void CumsumInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/cpu/where_grad_kernel.cc b/paddle/phi/kernels/cpu/where_grad_kernel.cc index 67c8cee103..a9cdbd7ad7 100644 --- a/paddle/phi/kernels/cpu/where_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/where_grad_kernel.cc @@ -14,6 +14,8 @@ #include "paddle/phi/kernels/where_grad_kernel.h" +#include "paddle/phi/core/kernel_registry.h" + namespace phi { template diff --git a/paddle/phi/kernels/cpu/where_kernel.cc b/paddle/phi/kernels/cpu/where_kernel.cc index f624c13c26..353d11c93c 100644 --- a/paddle/phi/kernels/cpu/where_kernel.cc +++ b/paddle/phi/kernels/cpu/where_kernel.cc @@ -14,6 +14,8 @@ #include "paddle/phi/kernels/where_kernel.h" +#include "paddle/phi/core/kernel_registry.h" + namespace phi { template diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 84da69ed5d..b75477a1af 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include #include #include @@ -33,7 +34,6 @@ #include "paddle/phi/common/float16.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" -#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/extensions.h" diff --git a/paddle/phi/kernels/gpu/where_grad_kernel.cu b/paddle/phi/kernels/gpu/where_grad_kernel.cu index f21aca80e2..14cc1d3113 100644 --- a/paddle/phi/kernels/gpu/where_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/where_grad_kernel.cu @@ -14,6 +14,9 @@ #include "paddle/phi/kernels/where_grad_kernel.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" + namespace phi { template diff --git a/paddle/phi/kernels/gpu/where_kernel.cu b/paddle/phi/kernels/gpu/where_kernel.cu index 03c24eea3a..a0be388065 100644 --- a/paddle/phi/kernels/gpu/where_kernel.cu +++ b/paddle/phi/kernels/gpu/where_kernel.cu @@ -14,6 +14,8 @@ #include "paddle/phi/kernels/where_kernel.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" diff --git a/paddle/phi/kernels/where_grad_kernel.h b/paddle/phi/kernels/where_grad_kernel.h index 1a3c66ee6e..5f596da93e 100644 --- a/paddle/phi/kernels/where_grad_kernel.h +++ b/paddle/phi/kernels/where_grad_kernel.h @@ -14,10 +14,7 @@ #pragma once -#include "paddle/phi/backends/all_context.h" -#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/kernel_registry.h" namespace phi { diff --git a/paddle/phi/kernels/where_kernel.h b/paddle/phi/kernels/where_kernel.h index 254271ac9c..6348177e69 100644 --- a/paddle/phi/kernels/where_kernel.h +++ b/paddle/phi/kernels/where_kernel.h @@ -14,10 +14,7 @@ #pragma once -#include "paddle/phi/backends/all_context.h" -#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/kernel_registry.h" namespace phi { diff --git a/paddle/phi/tests/core/test_custom_kernel.cc b/paddle/phi/tests/core/test_custom_kernel.cc index 07530f70b7..2a5b8ec8fa 100644 --- a/paddle/phi/tests/core/test_custom_kernel.cc +++ b/paddle/phi/tests/core/test_custom_kernel.cc @@ -49,7 +49,6 @@ void FakeDot(const Context& dev_ctx, float fake_attr_float, double fake_attr_double, int64_t fake_attr_int64, - phi::dtype::float16 fake_attr_f16, phi::DataType fake_attr_dtype, const phi::Scalar& fake_attr_scalar, const phi::IntArray& fake_attr_int_array, @@ -64,7 +63,6 @@ void FakeDot(const Context& dev_ctx, std::cout << "fake_attr_float: " << fake_attr_float << std::endl; std::cout << "fake_attr_double: " << fake_attr_double << std::endl; std::cout << "fake_attr_int64: " << fake_attr_int64 << std::endl; - std::cout << "fake_attr_f16: " << fake_attr_f16 << std::endl; std::cout << "fake_attr_dtype: " << fake_attr_dtype << std::endl; std::cout << "fake_attr_int64_vec: " << fake_attr_int64_vec.size() << std::endl; @@ -78,7 +76,6 @@ void FakeDot(const Context& dev_ctx, assert(fake_attr_float == 2); assert(fake_attr_double == 3); assert(fake_attr_int64 == 4); - assert(fake_attr_f16 == phi::dtype::float16(5)); assert(fake_attr_dtype == phi::DataType::UINT32); assert(fake_attr_int64_vec.size() == 0); assert(fake_attr_int_vec.size() == 0); @@ -248,7 +245,6 @@ TEST(CustomKernel, custom_kernel_dot) { float fake_attr_float = 2.0; double fake_attr_double = 3.0; int64_t fake_attr_int64 = 4; - phi::dtype::float16 fake_attr_f16 = phi::dtype::float16(5); phi::DataType fake_attr_dtype = phi::DataType::UINT32; paddle::framework::LoDTensor tmp_tensor; tmp_tensor.mutable_data({1}, phi::TransToPhiPlace(backend)); @@ -262,7 +258,6 @@ TEST(CustomKernel, custom_kernel_dot) { kernel_context.EmplaceBackAttr(fake_attr_float); kernel_context.EmplaceBackAttr(fake_attr_double); kernel_context.EmplaceBackAttr(fake_attr_int64); - kernel_context.EmplaceBackAttr(fake_attr_f16); kernel_context.EmplaceBackAttr(fake_attr_dtype); kernel_context.EmplaceBackAttr(fake_attr_scalar); kernel_context.EmplaceBackAttr(fake_attr_int_array); diff --git a/paddle/phi/tests/core/test_meta_fn_utils.cc b/paddle/phi/tests/core/test_meta_fn_utils.cc index 028b9d2335..07832494d5 100644 --- a/paddle/phi/tests/core/test_meta_fn_utils.cc +++ b/paddle/phi/tests/core/test_meta_fn_utils.cc @@ -60,32 +60,6 @@ TEST(MetaFnFactory, InferMetaFnExists) { EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]); } -TEST(MetaFnFactory, CopyInferMetaFn) { - phi::DenseTensor dense_x; - dense_x.Resize({3, 4}); - - phi::MetaTensor meta_x(&dense_x); - phi::DenseTensor dense_out1; - phi::MetaTensor meta_out(&dense_out1); - phi::UnchangedInferMeta(meta_x, &meta_out); - - auto shared_meat_x = phi::MetaTensor(&dense_x); - phi::DenseTensor dense_out2; - auto shared_meta_out = phi::MetaTensor(&dense_out2); - - phi::InferMetaContext ctx; - ctx.EmplaceBackInput(shared_meat_x); - ctx.EmplaceBackAttr(Backend::CPU); - ctx.EmplaceBackAttr(false); - ctx.EmplaceBackOutput(shared_meta_out); - ctx.SetMetaConfig({/*is_runtime =*/true, /*is_run_mkldnn_kernel=*/false}); - phi::MetaFnFactory::Instance().Get("copy_to")(&ctx); - - EXPECT_EQ(dense_out1.dims().size(), dense_out2.dims().size()); - EXPECT_EQ(dense_out1.dims()[0], dense_out2.dims()[0]); - EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]); -} - TEST(MetaFnFactory, SplitInferMetaFn) { phi::DenseTensor dense_x; dense_x.Resize({4, 10}); diff --git a/paddle/utils/variant.h b/paddle/utils/variant.h new file mode 100644 index 0000000000..a7546d094c --- /dev/null +++ b/paddle/utils/variant.h @@ -0,0 +1,2830 @@ +// Copy from +// https://github.com/mpark/variant/blob/single-header/v1.4.0/variant.hpp +// Modify the following points: +// 1. modify namespace mpark to namespace paddle + +// MPark.Variant +// +// Copyright Michael Park, 2015-2017 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE.md or copy at +// http://boost.org/LICENSE_1_0.txt) + +#pragma once + +/* + variant synopsis + +namespace std { + + // 20.7.2, class template variant + template + class variant { + public: + + // 20.7.2.1, constructors + constexpr variant() noexcept(see below); + variant(const variant&); + variant(variant&&) noexcept(see below); + + template constexpr variant(T&&) noexcept(see below); + + template + constexpr explicit variant(in_place_type_t, Args&&...); + + template + constexpr explicit variant( + in_place_type_t, initializer_list, Args&&...); + + template + constexpr explicit variant(in_place_index_t, Args&&...); + + template + constexpr explicit variant( + in_place_index_t, initializer_list, Args&&...); + + // 20.7.2.2, destructor + ~variant(); + + // 20.7.2.3, assignment + variant& operator=(const variant&); + variant& operator=(variant&&) noexcept(see below); + + template variant& operator=(T&&) noexcept(see below); + + // 20.7.2.4, modifiers + template + T& emplace(Args&&...); + + template + T& emplace(initializer_list, Args&&...); + + template + variant_alternative& emplace(Args&&...); + + template + variant_alternative& emplace(initializer_list, Args&&...); + + // 20.7.2.5, value status + constexpr bool valueless_by_exception() const noexcept; + constexpr size_t index() const noexcept; + + // 20.7.2.6, swap + void swap(variant&) noexcept(see below); + }; + + // 20.7.3, variant helper classes + template struct variant_size; // undefined + + template + constexpr size_t variant_size_v = variant_size::value; + + template struct variant_size; + template struct variant_size; + template struct variant_size; + + template + struct variant_size>; + + template struct variant_alternative; // undefined + + template + using variant_alternative_t = typename variant_alternative::type; + + template struct variant_alternative; + template struct variant_alternative; + template struct variant_alternative; + + template + struct variant_alternative>; + + constexpr size_t variant_npos = -1; + + // 20.7.4, value access + template + constexpr bool holds_alternative(const variant&) noexcept; + + template + constexpr variant_alternative_t>& + get(variant&); + + template + constexpr variant_alternative_t>&& + get(variant&&); + + template + constexpr variant_alternative_t> const& + get(const variant&); + + template + constexpr variant_alternative_t> const&& + get(const variant&&); + + template + constexpr T& get(variant&); + + template + constexpr T&& get(variant&&); + + template + constexpr const T& get(const variant&); + + template + constexpr const T&& get(const variant&&); + + template + constexpr add_pointer_t>> + get_if(variant*) noexcept; + + template + constexpr add_pointer_t>> + get_if(const variant*) noexcept; + + template + constexpr add_pointer_t + get_if(variant*) noexcept; + + template + constexpr add_pointer_t + get_if(const variant*) noexcept; + + // 20.7.5, relational operators + template + constexpr bool operator==(const variant&, const variant&); + + template + constexpr bool operator!=(const variant&, const variant&); + + template + constexpr bool operator<(const variant&, const variant&); + + template + constexpr bool operator>(const variant&, const variant&); + + template + constexpr bool operator<=(const variant&, const variant&); + + template + constexpr bool operator>=(const variant&, const variant&); + + // 20.7.6, visitation + template + constexpr see below visit(Visitor&&, Variants&&...); + + // 20.7.7, class monostate + struct monostate; + + // 20.7.8, monostate relational operators + constexpr bool operator<(monostate, monostate) noexcept; + constexpr bool operator>(monostate, monostate) noexcept; + constexpr bool operator<=(monostate, monostate) noexcept; + constexpr bool operator>=(monostate, monostate) noexcept; + constexpr bool operator==(monostate, monostate) noexcept; + constexpr bool operator!=(monostate, monostate) noexcept; + + // 20.7.9, specialized algorithms + template + void swap(variant&, variant&) noexcept(see below); + + // 20.7.10, class bad_variant_access + class bad_variant_access; + + // 20.7.11, hash support + template struct hash; + template struct hash>; + template <> struct hash; + +} // namespace std + +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +// MPark.Variant +// +// Copyright Michael Park, 2015-2017 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE.md or copy at +// http://boost.org/LICENSE_1_0.txt) + +#ifndef MPARK_CONFIG_HPP +#define MPARK_CONFIG_HPP + +// MSVC 2015 Update 3. +#if __cplusplus < 201103L && (!defined(_MSC_VER) || _MSC_FULL_VER < 190024210) +#error "MPark.Variant requires C++11 support." +#endif + +#ifndef __has_attribute +#define __has_attribute(x) 0 +#endif + +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif + +#ifndef __has_include +#define __has_include(x) 0 +#endif + +#ifndef __has_feature +#define __has_feature(x) 0 +#endif + +#if __has_attribute(always_inline) || defined(__GNUC__) +#define MPARK_ALWAYS_INLINE __attribute__((__always_inline__)) inline +#elif defined(_MSC_VER) +#define MPARK_ALWAYS_INLINE __forceinline +#else +#define MPARK_ALWAYS_INLINE inline +#endif + +#if __has_builtin(__builtin_addressof) || \ + (defined(__GNUC__) && __GNUC__ >= 7) || defined(_MSC_VER) +#define MPARK_BUILTIN_ADDRESSOF +#endif + +#if __has_builtin(__builtin_unreachable) || defined(__GNUC__) +#define MPARK_BUILTIN_UNREACHABLE __builtin_unreachable() +#elif defined(_MSC_VER) +#define MPARK_BUILTIN_UNREACHABLE __assume(false) +#else +#define MPARK_BUILTIN_UNREACHABLE +#endif + +#if __has_builtin(__type_pack_element) +#define MPARK_TYPE_PACK_ELEMENT +#endif + +#if defined(__cpp_constexpr) && __cpp_constexpr >= 200704 && \ + !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 9) +#define MPARK_CPP11_CONSTEXPR +#endif + +#if defined(__cpp_constexpr) && __cpp_constexpr >= 201304 +#define MPARK_CPP14_CONSTEXPR +#endif + +#if __has_feature(cxx_exceptions) || defined(__cpp_exceptions) || \ + (defined(_MSC_VER) && defined(_CPPUNWIND)) +#define MPARK_EXCEPTIONS +#endif + +#if defined(__cpp_generic_lambdas) || defined(_MSC_VER) +#define MPARK_GENERIC_LAMBDAS +#endif + +#if defined(__cpp_lib_integer_sequence) +#define MPARK_INTEGER_SEQUENCE +#endif + +#if defined(__cpp_return_type_deduction) || defined(_MSC_VER) +#define MPARK_RETURN_TYPE_DEDUCTION +#endif + +#if defined(__cpp_lib_transparent_operators) || defined(_MSC_VER) +#define MPARK_TRANSPARENT_OPERATORS +#endif + +#if defined(__cpp_variable_templates) || defined(_MSC_VER) +#define MPARK_VARIABLE_TEMPLATES +#endif + +#if !defined(__GLIBCXX__) || __has_include() // >= libstdc++-5 +#define MPARK_TRIVIALITY_TYPE_TRAITS +#define MPARK_INCOMPLETE_TYPE_TRAITS +#endif + +#endif // MPARK_CONFIG_HPP + +// MPark.Variant +// +// Copyright Michael Park, 2015-2017 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE.md or copy at +// http://boost.org/LICENSE_1_0.txt) + +#ifndef MPARK_IN_PLACE_HPP +#define MPARK_IN_PLACE_HPP + +#include + +namespace paddle { + +struct in_place_t { + explicit in_place_t() = default; +}; + +template +struct in_place_index_t { + explicit in_place_index_t() = default; +}; + +template +struct in_place_type_t { + explicit in_place_type_t() = default; +}; + +#ifdef MPARK_VARIABLE_TEMPLATES +constexpr in_place_t in_place{}; + +template +constexpr in_place_index_t in_place_index{}; + +template +constexpr in_place_type_t in_place_type{}; +#endif + +} // namespace paddle + +#endif // MPARK_IN_PLACE_HPP + +// MPark.Variant +// +// Copyright Michael Park, 2015-2017 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE.md or copy at +// http://boost.org/LICENSE_1_0.txt) + +#ifndef MPARK_LIB_HPP +#define MPARK_LIB_HPP + +#include +#include +#include +#include + +#define MPARK_RETURN(...) \ + noexcept(noexcept(__VA_ARGS__))->decltype(__VA_ARGS__) { return __VA_ARGS__; } + +namespace paddle { +namespace lib { +template +struct identity { + using type = T; +}; + +inline namespace cpp14 { +template +struct array { + constexpr const T &operator[](std::size_t index) const { return data[index]; } + + T data[N == 0 ? 1 : N]; +}; + +template +using add_pointer_t = typename std::add_pointer::type; + +template +using common_type_t = typename std::common_type::type; + +template +using decay_t = typename std::decay::type; + +template +using enable_if_t = typename std::enable_if::type; + +template +using remove_const_t = typename std::remove_const::type; + +template +using remove_reference_t = typename std::remove_reference::type; + +template +inline constexpr T &&forward(remove_reference_t &t) noexcept { + return static_cast(t); +} + +template +inline constexpr T &&forward(remove_reference_t &&t) noexcept { + static_assert(!std::is_lvalue_reference::value, + "can not forward an rvalue as an lvalue"); + return static_cast(t); +} + +template +inline constexpr remove_reference_t &&move(T &&t) noexcept { + return static_cast &&>(t); +} + +#ifdef MPARK_INTEGER_SEQUENCE +using std::integer_sequence; +using std::index_sequence; +using std::make_index_sequence; +using std::index_sequence_for; +#else +template +struct integer_sequence { + using value_type = T; + static constexpr std::size_t size() noexcept { return sizeof...(Is); } +}; + +template +using index_sequence = integer_sequence; + +template +struct make_index_sequence_concat; + +template +struct make_index_sequence_concat, + index_sequence> + : identity> {}; + +template +struct make_index_sequence_impl; + +template +using make_index_sequence = typename make_index_sequence_impl::type; + +template +struct make_index_sequence_impl + : make_index_sequence_concat, + make_index_sequence> {}; + +template <> +struct make_index_sequence_impl<0> : identity> {}; + +template <> +struct make_index_sequence_impl<1> : identity> {}; + +template +using index_sequence_for = make_index_sequence; +#endif + +// +#ifdef MPARK_TRANSPARENT_OPERATORS +using equal_to = std::equal_to<>; +#else +struct equal_to { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) == lib::forward(rhs)) +}; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS +using not_equal_to = std::not_equal_to<>; +#else +struct not_equal_to { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) != lib::forward(rhs)) +}; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS +using less = std::less<>; +#else +struct less { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) < lib::forward(rhs)) +}; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS +using greater = std::greater<>; +#else +struct greater { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) > lib::forward(rhs)) +}; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS +using less_equal = std::less_equal<>; +#else +struct less_equal { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) <= lib::forward(rhs)) +}; +#endif + +#ifdef MPARK_TRANSPARENT_OPERATORS +using greater_equal = std::greater_equal<>; +#else +struct greater_equal { + template + inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const + MPARK_RETURN(lib::forward(lhs) >= lib::forward(rhs)) +}; +#endif +} // namespace cpp14 + +inline namespace cpp17 { +// +template +using bool_constant = std::integral_constant; + +template +struct voider : identity {}; + +template +using void_t = typename voider::type; + +namespace detail { +namespace swappable { + +using std::swap; + +template +struct is_swappable { + private: + template (), std::declval()))> + inline static std::true_type test(int); + + template + inline static std::false_type test(...); + + public: + static constexpr bool value = decltype(test(0))::value; +}; + +template +struct is_nothrow_swappable { + static constexpr bool value = + noexcept(swap(std::declval(), std::declval())); +}; + +template +struct is_nothrow_swappable : std::false_type {}; + +} // namespace swappable +} // namespace detail + +using detail::swappable::is_swappable; + +template +using is_nothrow_swappable = + detail::swappable::is_nothrow_swappable::value, T>; + +// +namespace detail { + +template +struct is_reference_wrapper : std::false_type {}; + +template +struct is_reference_wrapper> : std::true_type {}; + +template +struct Invoke; + +template <> +struct Invoke { + template + inline static constexpr auto invoke(R T::*pmf, Arg &&arg, Args &&... args) + MPARK_RETURN((lib::forward(arg).*pmf)(lib::forward(args)...)) +}; + +template <> +struct Invoke { + template + inline static constexpr auto invoke(R T::*pmf, Arg &&arg, Args &&... args) + MPARK_RETURN((lib::forward(arg).get().* + pmf)(lib::forward(args)...)) +}; + +template <> +struct Invoke { + template + inline static constexpr auto invoke(R T::*pmf, Arg &&arg, Args &&... args) + MPARK_RETURN(((*lib::forward(arg)).* + pmf)(lib::forward(args)...)) +}; + +template <> +struct Invoke { + template + inline static constexpr auto invoke(R T::*pmo, Arg &&arg) + MPARK_RETURN(lib::forward(arg).*pmo) +}; + +template <> +struct Invoke { + template + inline static constexpr auto invoke(R T::*pmo, Arg &&arg) + MPARK_RETURN(lib::forward(arg).get().*pmo) +}; + +template <> +struct Invoke { + template + inline static constexpr auto invoke(R T::*pmo, Arg &&arg) + MPARK_RETURN((*lib::forward(arg)).*pmo) +}; + +template +inline constexpr auto invoke(R T::*f, Arg &&arg, Args &&... args) MPARK_RETURN( + Invoke::value, + (std::is_base_of>::value + ? 0 + : is_reference_wrapper>::value ? 1 : 2)>:: + invoke(f, lib::forward(arg), lib::forward(args)...)) + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + template + inline constexpr auto invoke(F &&f, Args &&... args) + MPARK_RETURN(lib::forward(f)(lib::forward(args)...)) +#ifdef _MSC_VER +#pragma warning(pop) +#endif +} // namespace detail + +template +inline constexpr auto invoke(F &&f, Args &&... args) + MPARK_RETURN(detail::invoke(lib::forward(f), + lib::forward(args)...)) + + namespace detail { + template + struct invoke_result {}; + + template + struct invoke_result< + void_t(), std::declval()...))>, + F, + Args...> : identity(), + std::declval()...))> {}; + +} // namespace detail + +template +using invoke_result = detail::invoke_result; + +template +using invoke_result_t = typename invoke_result::type; + +namespace detail { + +template +struct is_invocable : std::false_type {}; + +template +struct is_invocable>, F, Args...> + : std::true_type {}; + +template +struct is_invocable_r : std::false_type {}; + +template +struct is_invocable_r>, R, F, Args...> + : std::is_convertible, R> {}; + +} // namespace detail + +template +using is_invocable = detail::is_invocable; + +template +using is_invocable_r = detail::is_invocable_r; + +namespace detail { + +template +struct is_nothrow_invocable { + static constexpr bool value = + noexcept(lib::invoke(std::declval(), std::declval()...)); +}; + +template +struct is_nothrow_invocable : std::false_type {}; + +template +struct is_nothrow_invocable_r { + private: + inline static R impl() { + return lib::invoke(std::declval(), std::declval()...); + } + + public: + static constexpr bool value = noexcept(impl()); +}; + +template +struct is_nothrow_invocable_r : std::false_type {}; + +} // namespace detail + +template +using is_nothrow_invocable = + detail::is_nothrow_invocable::value, F, Args...>; + +template +using is_nothrow_invocable_r = detail:: + is_nothrow_invocable_r::value, R, F, Args...>; + +// +#ifdef MPARK_BUILTIN_ADDRESSOF +template +inline constexpr T *addressof(T &arg) noexcept { + return __builtin_addressof(arg); +} +#else +namespace detail { + +namespace has_addressof_impl { + +struct fail; + +template +inline fail operator&(T &&); + +template +inline static constexpr bool impl() { + return (std::is_class::value || std::is_union::value) && + !std::is_same()), fail>::value; +} + +} // namespace has_addressof_impl + +template +using has_addressof = bool_constant()>; + +template +inline constexpr T *addressof(T &arg, std::true_type) noexcept { + return std::addressof(arg); +} + +template +inline constexpr T *addressof(T &arg, std::false_type) noexcept { + return &arg; +} + +} // namespace detail + +template +inline constexpr T *addressof(T &arg) noexcept { + return detail::addressof(arg, detail::has_addressof{}); +} +#endif + +template +inline constexpr T *addressof(const T &&) = delete; + +} // namespace cpp17 + +template +struct remove_all_extents : identity {}; + +template +struct remove_all_extents> : remove_all_extents {}; + +template +using remove_all_extents_t = typename remove_all_extents::type; + +template +using size_constant = std::integral_constant; + +template +struct indexed_type : size_constant { + using type = T; +}; + +template +using all = std::is_same, + integer_sequence>; + +#ifdef MPARK_TYPE_PACK_ELEMENT +template +using type_pack_element_t = __type_pack_element; +#else +template +struct type_pack_element_impl { + private: + template + struct set; + + template + struct set> : indexed_type... {}; + + template + inline static std::enable_if impl(indexed_type); + + inline static std::enable_if impl(...); + + public: + using type = decltype(impl(set>{})); +}; + +template +using type_pack_element = typename type_pack_element_impl::type; + +template +using type_pack_element_t = typename type_pack_element::type; +#endif + +#ifdef MPARK_TRIVIALITY_TYPE_TRAITS +using std::is_trivially_copy_constructible; +using std::is_trivially_move_constructible; +using std::is_trivially_copy_assignable; +using std::is_trivially_move_assignable; +#else +template +struct is_trivially_copy_constructible + : bool_constant::value &&__has_trivial_copy( + T)> {}; + +template +struct is_trivially_move_constructible : bool_constant<__is_trivial(T)> {}; + +template +struct is_trivially_copy_assignable + : bool_constant::value &&__has_trivial_assign( + T)> {}; + +template +struct is_trivially_move_assignable : bool_constant<__is_trivial(T)> {}; +#endif + +template +struct dependent_type : T {}; + +template +struct push_back; + +template +using push_back_t = typename push_back::type; + +template +struct push_back, J> { + using type = index_sequence; +}; + +} // namespace lib +} // namespace paddle + +#undef MPARK_RETURN + +#endif // MPARK_LIB_HPP + +namespace paddle { + +#ifdef MPARK_RETURN_TYPE_DEDUCTION + +#define AUTO auto +#define AUTO_RETURN(...) \ + { return __VA_ARGS__; } + +#define AUTO_REFREF auto && +#define AUTO_REFREF_RETURN(...) \ + { return __VA_ARGS__; } + +#define DECLTYPE_AUTO decltype(auto) +#define DECLTYPE_AUTO_RETURN(...) \ + { return __VA_ARGS__; } + +#else + +#define AUTO auto +#define AUTO_RETURN(...) \ + ->lib::decay_t { return __VA_ARGS__; } + +#define AUTO_REFREF auto +#define AUTO_REFREF_RETURN(...) \ + ->decltype((__VA_ARGS__)) { \ + static_assert(std::is_reference::value, ""); \ + return __VA_ARGS__; \ + } + +#define DECLTYPE_AUTO auto +#define DECLTYPE_AUTO_RETURN(...) \ + ->decltype(__VA_ARGS__) { return __VA_ARGS__; } + +#endif + +class bad_variant_access : public std::exception { + public: + virtual const char *what() const noexcept override { + return "bad_variant_access"; + } +}; + +[[noreturn]] inline void throw_bad_variant_access() { +#ifdef MPARK_EXCEPTIONS + throw bad_variant_access{}; +#else + std::terminate(); + MPARK_BUILTIN_UNREACHABLE; +#endif +} + +template +class variant; + +template +struct variant_size; + +#ifdef MPARK_VARIABLE_TEMPLATES +template +constexpr std::size_t variant_size_v = variant_size::value; +#endif + +template +struct variant_size : variant_size {}; + +template +struct variant_size : variant_size {}; + +template +struct variant_size : variant_size {}; + +template +struct variant_size> : lib::size_constant {}; + +template +struct variant_alternative; + +template +using variant_alternative_t = typename variant_alternative::type; + +template +struct variant_alternative + : std::add_const> {}; + +template +struct variant_alternative + : std::add_volatile> {}; + +template +struct variant_alternative + : std::add_cv> {}; + +template +struct variant_alternative> { + static_assert(I < sizeof...(Ts), + "index out of bounds in `std::variant_alternative<>`"); + using type = lib::type_pack_element_t; +}; + +constexpr std::size_t variant_npos = static_cast(-1); + +namespace detail { + +constexpr std::size_t not_found = static_cast(-1); +constexpr std::size_t ambiguous = static_cast(-2); + +#ifdef MPARK_CPP14_CONSTEXPR +template +inline constexpr std::size_t find_index() { + constexpr lib::array matches = { + {std::is_same::value...}}; + std::size_t result = not_found; + for (std::size_t i = 0; i < sizeof...(Ts); ++i) { + if (matches[i]) { + if (result != not_found) { + return ambiguous; + } + result = i; + } + } + return result; +} +#else +inline constexpr std::size_t find_index_impl(std::size_t result, std::size_t) { + return result; +} + +template +inline constexpr std::size_t find_index_impl(std::size_t result, + std::size_t idx, + bool b, + Bs... bs) { + return b ? (result != not_found ? ambiguous + : find_index_impl(idx, idx + 1, bs...)) + : find_index_impl(result, idx + 1, bs...); +} + +template +inline constexpr std::size_t find_index() { + return find_index_impl(not_found, 0, std::is_same::value...); +} +#endif + +template +using find_index_sfinae_impl = + lib::enable_if_t>; + +template +using find_index_sfinae = find_index_sfinae_impl()>; + +template +struct find_index_checked_impl : lib::size_constant { + static_assert(I != not_found, "the specified type is not found."); + static_assert(I != ambiguous, "the specified type is ambiguous."); +}; + +template +using find_index_checked = find_index_checked_impl()>; + +struct valueless_t {}; + +enum class Trait { TriviallyAvailable, Available, Unavailable }; + +template class IsTriviallyAvailable, + template class IsAvailable> +inline constexpr Trait trait() { + return IsTriviallyAvailable::value + ? Trait::TriviallyAvailable + : IsAvailable::value ? Trait::Available : Trait::Unavailable; +} + +#ifdef MPARK_CPP14_CONSTEXPR +template +inline constexpr Trait common_trait(Traits... traits_) { + Trait result = Trait::TriviallyAvailable; + lib::array traits = {{traits_...}}; + for (std::size_t i = 0; i < sizeof...(Traits); ++i) { + Trait t = traits[i]; + if (static_cast(t) > static_cast(result)) { + result = t; + } + } + return result; +} +#else +inline constexpr Trait common_trait_impl(Trait result) { return result; } + +template +inline constexpr Trait common_trait_impl(Trait result, Trait t, Traits... ts) { + return static_cast(t) > static_cast(result) + ? common_trait_impl(t, ts...) + : common_trait_impl(result, ts...); +} + +template +inline constexpr Trait common_trait(Traits... ts) { + return common_trait_impl(Trait::TriviallyAvailable, ts...); +} +#endif + +template +struct traits { + static constexpr Trait copy_constructible_trait = + common_trait(trait()...); + + static constexpr Trait move_constructible_trait = + common_trait(trait()...); + + static constexpr Trait copy_assignable_trait = + common_trait(copy_constructible_trait, + trait()...); + + static constexpr Trait move_assignable_trait = + common_trait(move_constructible_trait, + trait()...); + + static constexpr Trait destructible_trait = common_trait( + trait()...); +}; + +namespace access { + +struct recursive_union { +#ifdef MPARK_RETURN_TYPE_DEDUCTION + template + inline static constexpr auto &&get_alt(V &&v, in_place_index_t<0>) { + return lib::forward(v).head_; + } + + template + inline static constexpr auto &&get_alt(V &&v, in_place_index_t) { + return get_alt(lib::forward(v).tail_, in_place_index_t{}); + } +#else + template + struct get_alt_impl { + template + inline constexpr AUTO_REFREF operator()(V &&v) const + AUTO_REFREF_RETURN(get_alt_impl{}(lib::forward(v).tail_)) + }; + + template + struct get_alt_impl<0, Dummy> { + template + inline constexpr AUTO_REFREF operator()(V &&v) const + AUTO_REFREF_RETURN(lib::forward(v).head_) + }; + + template + inline static constexpr AUTO_REFREF get_alt(V &&v, in_place_index_t) + AUTO_REFREF_RETURN(get_alt_impl{}(lib::forward(v))) +#endif +}; + +struct base { + template + inline static constexpr AUTO_REFREF get_alt(V &&v) +#ifdef _MSC_VER + AUTO_REFREF_RETURN(recursive_union::get_alt(lib::forward(v).data_, + in_place_index_t{})) +#else + AUTO_REFREF_RETURN(recursive_union::get_alt(data(lib::forward(v)), + in_place_index_t{})) +#endif +}; + +struct variant { + template + inline static constexpr AUTO_REFREF get_alt(V &&v) + AUTO_REFREF_RETURN(base::get_alt(lib::forward(v).impl_)) +}; + +} // namespace access + +namespace visitation { + +#if defined(MPARK_CPP14_CONSTEXPR) && !defined(_MSC_VER) +#define MPARK_VARIANT_SWITCH_VISIT +#endif + +struct base { + template + using dispatch_result_t = + decltype(lib::invoke(std::declval(), + access::base::get_alt<0>(std::declval())...)); + + template + struct expected { + template + inline static constexpr bool but_got() { + return std::is_same::value; + } + }; + + template + struct visit_return_type_check { + static_assert(expected::template but_got(), + "`visit` requires the visitor to have a single return type"); + + template + inline static constexpr DECLTYPE_AUTO invoke(Visitor &&visitor, + Alts &&... alts) + DECLTYPE_AUTO_RETURN(lib::invoke(lib::forward(visitor), + lib::forward(alts)...)) + }; + +#ifdef MPARK_VARIANT_SWITCH_VISIT + template + struct dispatcher; + + template + struct dispatcher { + template + MPARK_ALWAYS_INLINE static constexpr R dispatch(F &&, + typename ITs::type &&..., + Vs &&...) { + MPARK_BUILTIN_UNREACHABLE; + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch_case(F &&, Vs &&...) { + MPARK_BUILTIN_UNREACHABLE; + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch_at(std::size_t, + F &&, + Vs &&...) { + MPARK_BUILTIN_UNREACHABLE; + } + }; + + template + struct dispatcher { + template + MPARK_ALWAYS_INLINE static constexpr R dispatch( + F &&f, typename ITs::type &&... visited_vs) { + using Expected = R; + using Actual = decltype( + lib::invoke(lib::forward(f), + access::base::get_alt( + lib::forward(visited_vs))...)); + return visit_return_type_check::invoke( + lib::forward(f), + access::base::get_alt( + lib::forward(visited_vs))...); + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch( + F &&f, typename ITs::type &&... visited_vs, V &&v, Vs &&... vs) { +#define MPARK_DISPATCH(I) \ + dispatcher<(I < lib::decay_t::size()), \ + R, \ + ITs..., \ + lib::indexed_type>:: \ + template dispatch<0>(lib::forward(f), \ + lib::forward(visited_vs)..., \ + lib::forward(v), \ + lib::forward(vs)...) + +#define MPARK_DEFAULT(I) \ + dispatcher<(I < lib::decay_t::size()), R, ITs...>::template dispatch( \ + lib::forward(f), \ + lib::forward(visited_vs)..., \ + lib::forward(v), \ + lib::forward(vs)...) + + switch (v.index()) { + case B + 0: + return MPARK_DISPATCH(B + 0); + case B + 1: + return MPARK_DISPATCH(B + 1); + case B + 2: + return MPARK_DISPATCH(B + 2); + case B + 3: + return MPARK_DISPATCH(B + 3); + case B + 4: + return MPARK_DISPATCH(B + 4); + case B + 5: + return MPARK_DISPATCH(B + 5); + case B + 6: + return MPARK_DISPATCH(B + 6); + case B + 7: + return MPARK_DISPATCH(B + 7); + case B + 8: + return MPARK_DISPATCH(B + 8); + case B + 9: + return MPARK_DISPATCH(B + 9); + case B + 10: + return MPARK_DISPATCH(B + 10); + case B + 11: + return MPARK_DISPATCH(B + 11); + case B + 12: + return MPARK_DISPATCH(B + 12); + case B + 13: + return MPARK_DISPATCH(B + 13); + case B + 14: + return MPARK_DISPATCH(B + 14); + case B + 15: + return MPARK_DISPATCH(B + 15); + case B + 16: + return MPARK_DISPATCH(B + 16); + case B + 17: + return MPARK_DISPATCH(B + 17); + case B + 18: + return MPARK_DISPATCH(B + 18); + case B + 19: + return MPARK_DISPATCH(B + 19); + case B + 20: + return MPARK_DISPATCH(B + 20); + case B + 21: + return MPARK_DISPATCH(B + 21); + case B + 22: + return MPARK_DISPATCH(B + 22); + case B + 23: + return MPARK_DISPATCH(B + 23); + case B + 24: + return MPARK_DISPATCH(B + 24); + case B + 25: + return MPARK_DISPATCH(B + 25); + case B + 26: + return MPARK_DISPATCH(B + 26); + case B + 27: + return MPARK_DISPATCH(B + 27); + case B + 28: + return MPARK_DISPATCH(B + 28); + case B + 29: + return MPARK_DISPATCH(B + 29); + case B + 30: + return MPARK_DISPATCH(B + 30); + case B + 31: + return MPARK_DISPATCH(B + 31); + default: + return MPARK_DEFAULT(B + 32); + } + +#undef MPARK_DEFAULT +#undef MPARK_DISPATCH + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch_case(F &&f, Vs &&... vs) { + using Expected = R; + using Actual = decltype( + lib::invoke(lib::forward(f), + access::base::get_alt(lib::forward(vs))...)); + return visit_return_type_check::invoke( + lib::forward(f), + access::base::get_alt(lib::forward(vs))...); + } + + template + MPARK_ALWAYS_INLINE static constexpr R dispatch_at(std::size_t index, + F &&f, + V &&v, + Vs &&... vs) { + static_assert(lib::all<(lib::decay_t::size() == + lib::decay_t::size())...>::value, + "all of the variants must be the same size."); +#define MPARK_DISPATCH_AT(I) \ + dispatcher<(I < lib::decay_t::size()), R>::template dispatch_case( \ + lib::forward(f), lib::forward(v), lib::forward(vs)...) + +#define MPARK_DEFAULT(I) \ + dispatcher<(I < lib::decay_t::size()), R>::template dispatch_at( \ + index, lib::forward(f), lib::forward(v), lib::forward(vs)...) + + switch (index) { + case B + 0: + return MPARK_DISPATCH_AT(B + 0); + case B + 1: + return MPARK_DISPATCH_AT(B + 1); + case B + 2: + return MPARK_DISPATCH_AT(B + 2); + case B + 3: + return MPARK_DISPATCH_AT(B + 3); + case B + 4: + return MPARK_DISPATCH_AT(B + 4); + case B + 5: + return MPARK_DISPATCH_AT(B + 5); + case B + 6: + return MPARK_DISPATCH_AT(B + 6); + case B + 7: + return MPARK_DISPATCH_AT(B + 7); + case B + 8: + return MPARK_DISPATCH_AT(B + 8); + case B + 9: + return MPARK_DISPATCH_AT(B + 9); + case B + 10: + return MPARK_DISPATCH_AT(B + 10); + case B + 11: + return MPARK_DISPATCH_AT(B + 11); + case B + 12: + return MPARK_DISPATCH_AT(B + 12); + case B + 13: + return MPARK_DISPATCH_AT(B + 13); + case B + 14: + return MPARK_DISPATCH_AT(B + 14); + case B + 15: + return MPARK_DISPATCH_AT(B + 15); + case B + 16: + return MPARK_DISPATCH_AT(B + 16); + case B + 17: + return MPARK_DISPATCH_AT(B + 17); + case B + 18: + return MPARK_DISPATCH_AT(B + 18); + case B + 19: + return MPARK_DISPATCH_AT(B + 19); + case B + 20: + return MPARK_DISPATCH_AT(B + 20); + case B + 21: + return MPARK_DISPATCH_AT(B + 21); + case B + 22: + return MPARK_DISPATCH_AT(B + 22); + case B + 23: + return MPARK_DISPATCH_AT(B + 23); + case B + 24: + return MPARK_DISPATCH_AT(B + 24); + case B + 25: + return MPARK_DISPATCH_AT(B + 25); + case B + 26: + return MPARK_DISPATCH_AT(B + 26); + case B + 27: + return MPARK_DISPATCH_AT(B + 27); + case B + 28: + return MPARK_DISPATCH_AT(B + 28); + case B + 29: + return MPARK_DISPATCH_AT(B + 29); + case B + 30: + return MPARK_DISPATCH_AT(B + 30); + case B + 31: + return MPARK_DISPATCH_AT(B + 31); + default: + return MPARK_DEFAULT(B + 32); + } + +#undef MPARK_DEFAULT +#undef MPARK_DISPATCH_AT + } + }; +#else + template + inline static constexpr const T &at(const T &elem) noexcept { + return elem; + } + + template + inline static constexpr const lib::remove_all_extents_t &at( + const lib::array &elems, std::size_t i, Is... is) noexcept { + return at(elems[i], is...); + } + + template + inline static constexpr lib::array, sizeof...(Fs) + 1> + make_farray(F &&f, Fs &&... fs) { + return {{lib::forward(f), lib::forward(fs)...}}; + } + + template + struct make_fmatrix_impl { + template + inline static constexpr dispatch_result_t dispatch(F &&f, + Vs &&... vs) { + using Expected = dispatch_result_t; + using Actual = decltype( + lib::invoke(lib::forward(f), + access::base::get_alt(lib::forward(vs))...)); + return visit_return_type_check::invoke( + lib::forward(f), + access::base::get_alt(lib::forward(vs))...); + } + +#ifdef MPARK_RETURN_TYPE_DEDUCTION + template + inline static constexpr auto impl(lib::index_sequence) { + return &dispatch; + } + + template + inline static constexpr auto impl(Is, + lib::index_sequence, + Ls... ls) { + return make_farray(impl(lib::push_back_t{}, ls...)...); + } +#else + template + struct impl; + + template + struct impl> { + inline constexpr AUTO operator()() const AUTO_RETURN(&dispatch) + }; + + template + struct impl, Ls...> { + inline constexpr AUTO operator()() const + AUTO_RETURN(make_farray(impl, Ls...>{}()...)) + }; +#endif + }; + +#ifdef MPARK_RETURN_TYPE_DEDUCTION + template + inline static constexpr auto make_fmatrix() { + return make_fmatrix_impl::impl( + lib::index_sequence<>{}, + lib::make_index_sequence::size()>{}...); + } +#else + template + inline static constexpr AUTO make_fmatrix() + AUTO_RETURN(typename make_fmatrix_impl::template impl< + lib::index_sequence<>, + lib::make_index_sequence::size()>...>{}()) +#endif + + template + struct make_fdiagonal_impl { + template + inline static constexpr dispatch_result_t dispatch(F &&f, + Vs &&... vs) { + using Expected = dispatch_result_t; + using Actual = decltype( + lib::invoke(lib::forward(f), + access::base::get_alt(lib::forward(vs))...)); + return visit_return_type_check::invoke( + lib::forward(f), + access::base::get_alt(lib::forward(vs))...); + } + + template + inline static constexpr AUTO impl(lib::index_sequence) + AUTO_RETURN(make_farray(&dispatch...)) + }; + + template + inline static constexpr auto make_fdiagonal() + -> decltype(make_fdiagonal_impl::impl( + lib::make_index_sequence::size()>{})) { + static_assert(lib::all<(lib::decay_t::size() == + lib::decay_t::size())...>::value, + "all of the variants must be the same size."); + return make_fdiagonal_impl::impl( + lib::make_index_sequence::size()>{}); + } +#endif +}; + +#if !defined(MPARK_VARIANT_SWITCH_VISIT) && \ + (!defined(_MSC_VER) || _MSC_VER >= 1910) +template +using fmatrix_t = decltype(base::make_fmatrix()); + +template +struct fmatrix { + static constexpr fmatrix_t value = base::make_fmatrix(); +}; + +template +constexpr fmatrix_t fmatrix::value; + +template +using fdiagonal_t = decltype(base::make_fdiagonal()); + +template +struct fdiagonal { + static constexpr fdiagonal_t value = + base::make_fdiagonal(); +}; + +template +constexpr fdiagonal_t fdiagonal::value; +#endif + +struct alt { + template + inline static constexpr DECLTYPE_AUTO visit_alt(Visitor &&visitor, + Vs &&... vs) +#ifdef MPARK_VARIANT_SWITCH_VISIT + DECLTYPE_AUTO_RETURN( + base::dispatcher(vs)))...>>:: + template dispatch<0>(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#elif !defined(_MSC_VER) || _MSC_VER >= 1910 + DECLTYPE_AUTO_RETURN( + base::at(fmatrix(vs)))...>::value, + vs.index()...)(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#else + DECLTYPE_AUTO_RETURN(base::at( + base::make_fmatrix(vs)))...>(), + vs.index()...)(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#endif + + template + inline static constexpr DECLTYPE_AUTO + visit_alt_at(std::size_t index, Visitor &&visitor, Vs &&... vs) +#ifdef MPARK_VARIANT_SWITCH_VISIT + DECLTYPE_AUTO_RETURN( + base::dispatcher< + true, + base::dispatch_result_t< + Visitor, + decltype(as_base(lib::forward(vs)))...>>:: + template dispatch_at<0>(index, + lib::forward(visitor), + as_base(lib::forward(vs))...)) +#elif !defined(_MSC_VER) || _MSC_VER >= 1910 + DECLTYPE_AUTO_RETURN(base::at( + fdiagonal(vs)))...>::value, + index)(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#else + DECLTYPE_AUTO_RETURN( + base::at(base::make_fdiagonal< + Visitor &&, + decltype(as_base(lib::forward(vs)))...>(), + index)(lib::forward(visitor), + as_base(lib::forward(vs))...)) +#endif +}; + +struct variant { + private: + template + struct visitor { + template + inline static constexpr bool does_not_handle() { + return lib::is_invocable::value; + } + }; + + template + struct visit_exhaustiveness_check { + static_assert(visitor::template does_not_handle(), + "`visit` requires the visitor to be exhaustive."); + + inline static constexpr DECLTYPE_AUTO invoke(Visitor &&visitor, + Values &&... values) + DECLTYPE_AUTO_RETURN(lib::invoke(lib::forward(visitor), + lib::forward(values)...)) + }; + + template + struct value_visitor { + Visitor &&visitor_; + + template + inline constexpr DECLTYPE_AUTO operator()(Alts &&... alts) const + DECLTYPE_AUTO_RETURN(visit_exhaustiveness_check< + Visitor, + decltype((lib::forward(alts).value))...>:: + invoke(lib::forward(visitor_), + lib::forward(alts).value...)) + }; + + template + inline static constexpr AUTO make_value_visitor(Visitor &&visitor) + AUTO_RETURN(value_visitor{lib::forward(visitor)}) + + public + : template + inline static constexpr DECLTYPE_AUTO + visit_alt(Visitor &&visitor, Vs &&... vs) + DECLTYPE_AUTO_RETURN(alt::visit_alt(lib::forward(visitor), + lib::forward(vs).impl_...)) + + template + inline static constexpr DECLTYPE_AUTO + visit_alt_at(std::size_t index, Visitor &&visitor, Vs &&... vs) + DECLTYPE_AUTO_RETURN( + alt::visit_alt_at(index, + lib::forward(visitor), + lib::forward(vs).impl_...)) + + template + inline static constexpr DECLTYPE_AUTO + visit_value(Visitor &&visitor, Vs &&... vs) DECLTYPE_AUTO_RETURN( + visit_alt(make_value_visitor(lib::forward(visitor)), + lib::forward(vs)...)) + + template + inline static constexpr DECLTYPE_AUTO + visit_value_at(std::size_t index, Visitor &&visitor, Vs &&... vs) + DECLTYPE_AUTO_RETURN( + visit_alt_at(index, + make_value_visitor(lib::forward(visitor)), + lib::forward(vs)...)) +}; + +} // namespace visitation + +template +struct alt { + using value_type = T; + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4244) +#endif + template + inline explicit constexpr alt(in_place_t, Args &&... args) + : value(lib::forward(args)...) {} +#ifdef _MSC_VER +#pragma warning(pop) +#endif + + T value; +}; + +template +union recursive_union; + +template +union recursive_union {}; + +#define MPARK_VARIANT_RECURSIVE_UNION(destructible_trait, destructor) \ + template \ + union recursive_union { \ + public: \ + inline explicit constexpr recursive_union(valueless_t) noexcept \ + : dummy_{} {} \ + \ + template \ + inline explicit constexpr recursive_union(in_place_index_t<0>, \ + Args &&... args) \ + : head_(in_place_t{}, lib::forward(args)...) {} \ + \ + template \ + inline explicit constexpr recursive_union(in_place_index_t, \ + Args &&... args) \ + : tail_(in_place_index_t{}, lib::forward(args)...) {} \ + \ + recursive_union(const recursive_union &) = default; \ + recursive_union(recursive_union &&) = default; \ + \ + destructor \ + \ + recursive_union & \ + operator=(const recursive_union &) = default; \ + recursive_union &operator=(recursive_union &&) = default; \ + \ + private: \ + char dummy_; \ + alt head_; \ + recursive_union tail_; \ + \ + friend struct access::recursive_union; \ + } + +MPARK_VARIANT_RECURSIVE_UNION(Trait::TriviallyAvailable, + ~recursive_union() = default;); +MPARK_VARIANT_RECURSIVE_UNION(Trait::Available, ~recursive_union(){}); +MPARK_VARIANT_RECURSIVE_UNION(Trait::Unavailable, ~recursive_union() = delete;); + +#undef MPARK_VARIANT_RECURSIVE_UNION + +using index_t = unsigned int; + +template +class base { + public: + inline explicit constexpr base(valueless_t tag) noexcept + : data_(tag), + index_(static_cast(-1)) {} + + template + inline explicit constexpr base(in_place_index_t, Args &&... args) + : data_(in_place_index_t{}, lib::forward(args)...), index_(I) {} + + inline constexpr bool valueless_by_exception() const noexcept { + return index_ == static_cast(-1); + } + + inline constexpr std::size_t index() const noexcept { + return valueless_by_exception() ? variant_npos : index_; + } + + protected: + using data_t = recursive_union; + + friend inline constexpr base &as_base(base &b) { return b; } + friend inline constexpr const base &as_base(const base &b) { return b; } + friend inline constexpr base &&as_base(base &&b) { return lib::move(b); } + friend inline constexpr const base &&as_base(const base &&b) { + return lib::move(b); + } + + friend inline constexpr data_t &data(base &b) { return b.data_; } + friend inline constexpr const data_t &data(const base &b) { return b.data_; } + friend inline constexpr data_t &&data(base &&b) { return lib::move(b).data_; } + friend inline constexpr const data_t &&data(const base &&b) { + return lib::move(b).data_; + } + + inline static constexpr std::size_t size() { return sizeof...(Ts); } + + data_t data_; + index_t index_; + + friend struct access::base; + friend struct visitation::base; +}; + +struct dtor { +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + template + inline void operator()(Alt &alt) const noexcept { + alt.~Alt(); + } +#ifdef _MSC_VER +#pragma warning(pop) +#endif +}; + +#if !defined(_MSC_VER) || _MSC_VER >= 1910 +#define MPARK_INHERITING_CTOR(type, base) using base::base; +#else +#define MPARK_INHERITING_CTOR(type, base) \ + template \ + inline explicit constexpr type(Args &&... args) \ + : base(lib::forward(args)...) {} +#endif + +template +class destructor; + +#define MPARK_VARIANT_DESTRUCTOR(destructible_trait, definition, destroy) \ + template \ + class destructor, destructible_trait> \ + : public base { \ + using super = base; \ + \ + public: \ + MPARK_INHERITING_CTOR(destructor, super) \ + using super::operator=; \ + \ + destructor(const destructor &) = default; \ + destructor(destructor &&) = default; \ + definition destructor &operator=(const destructor &) = default; \ + destructor &operator=(destructor &&) = default; \ + \ + protected: \ + destroy \ + } + +MPARK_VARIANT_DESTRUCTOR(Trait::TriviallyAvailable, ~destructor() = default; + , inline void destroy() noexcept { + this->index_ = static_cast(-1); + }); + +MPARK_VARIANT_DESTRUCTOR(Trait::Available, + ~destructor() { destroy(); }, + inline void destroy() noexcept { + if (!this->valueless_by_exception()) { + visitation::alt::visit_alt(dtor{}, *this); + } + this->index_ = static_cast(-1); + }); + +MPARK_VARIANT_DESTRUCTOR(Trait::Unavailable, ~destructor() = delete; + , inline void destroy() noexcept = delete;); + +#undef MPARK_VARIANT_DESTRUCTOR + +template +class constructor : public destructor { + using super = destructor; + + public: + MPARK_INHERITING_CTOR(constructor, super) + using super::operator=; + + protected: +#ifndef MPARK_GENERIC_LAMBDAS + struct ctor { + template + inline void operator()(LhsAlt &lhs_alt, RhsAlt &&rhs_alt) const { + constructor::construct_alt(lhs_alt, lib::forward(rhs_alt).value); + } + }; +#endif + + template + inline static T &construct_alt(alt &a, Args &&... args) { + auto *result = ::new (static_cast(lib::addressof(a))) + alt(in_place_t{}, lib::forward(args)...); + return result->value; + } + + template + inline static void generic_construct(constructor &lhs, Rhs &&rhs) { + lhs.destroy(); + if (!rhs.valueless_by_exception()) { + visitation::alt::visit_alt_at( + rhs.index(), +#ifdef MPARK_GENERIC_LAMBDAS + [](auto &lhs_alt, auto &&rhs_alt) { + constructor::construct_alt( + lhs_alt, lib::forward(rhs_alt).value); + } +#else + ctor {} +#endif + , + lhs, + lib::forward(rhs)); + lhs.index_ = rhs.index_; + } + } +}; + +template +class move_constructor; + +#define MPARK_VARIANT_MOVE_CONSTRUCTOR(move_constructible_trait, definition) \ + template \ + class move_constructor, move_constructible_trait> \ + : public constructor> { \ + using super = constructor>; \ + \ + public: \ + MPARK_INHERITING_CTOR(move_constructor, super) \ + using super::operator=; \ + \ + move_constructor(const move_constructor &) = default; \ + definition ~move_constructor() = default; \ + move_constructor &operator=(const move_constructor &) = default; \ + move_constructor &operator=(move_constructor &&) = default; \ + } + +MPARK_VARIANT_MOVE_CONSTRUCTOR( + Trait::TriviallyAvailable, + move_constructor(move_constructor &&that) = default;); + +MPARK_VARIANT_MOVE_CONSTRUCTOR( + Trait::Available, + move_constructor(move_constructor &&that) noexcept( + lib::all::value...>::value) + : move_constructor(valueless_t{}) { + this->generic_construct(*this, lib::move(that)); + }); + +MPARK_VARIANT_MOVE_CONSTRUCTOR(Trait::Unavailable, + move_constructor(move_constructor &&) = delete;); + +#undef MPARK_VARIANT_MOVE_CONSTRUCTOR + +template +class copy_constructor; + +#define MPARK_VARIANT_COPY_CONSTRUCTOR(copy_constructible_trait, definition) \ + template \ + class copy_constructor, copy_constructible_trait> \ + : public move_constructor> { \ + using super = move_constructor>; \ + \ + public: \ + MPARK_INHERITING_CTOR(copy_constructor, super) \ + using super::operator=; \ + \ + definition copy_constructor(copy_constructor &&) = default; \ + ~copy_constructor() = default; \ + copy_constructor &operator=(const copy_constructor &) = default; \ + copy_constructor &operator=(copy_constructor &&) = default; \ + } + +MPARK_VARIANT_COPY_CONSTRUCTOR( + Trait::TriviallyAvailable, + copy_constructor(const copy_constructor &that) = default;); + +MPARK_VARIANT_COPY_CONSTRUCTOR(Trait::Available, + copy_constructor(const copy_constructor &that) + : copy_constructor(valueless_t{}) { + this->generic_construct(*this, that); + }); + +MPARK_VARIANT_COPY_CONSTRUCTOR( + Trait::Unavailable, copy_constructor(const copy_constructor &) = delete;); + +#undef MPARK_VARIANT_COPY_CONSTRUCTOR + +template +class assignment : public copy_constructor { + using super = copy_constructor; + + public: + MPARK_INHERITING_CTOR(assignment, super) + using super::operator=; + + template + inline /* auto & */ auto emplace(Args &&... args) + -> decltype(this->construct_alt(access::base::get_alt(*this), + lib::forward(args)...)) { + this->destroy(); + auto &result = this->construct_alt(access::base::get_alt(*this), + lib::forward(args)...); + this->index_ = I; + return result; + } + + protected: +#ifndef MPARK_GENERIC_LAMBDAS + template + struct assigner { + template + inline void operator()(ThisAlt &this_alt, ThatAlt &&that_alt) const { + self->assign_alt(this_alt, lib::forward(that_alt).value); + } + assignment *self; + }; +#endif + + template + inline void assign_alt(alt &a, Arg &&arg) { + if (this->index() == I) { +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4244) +#endif + a.value = lib::forward(arg); +#ifdef _MSC_VER +#pragma warning(pop) +#endif + } else { + struct { + void operator()(std::true_type) const { + this_->emplace(lib::forward(arg_)); + } + void operator()(std::false_type) const { + this_->emplace(T(lib::forward(arg_))); + } + assignment *this_; + Arg &&arg_; + } impl{this, lib::forward(arg)}; + impl(lib::bool_constant < std::is_nothrow_constructible::value || + !std::is_nothrow_move_constructible::value > {}); + } + } + + template + inline void generic_assign(That &&that) { + if (this->valueless_by_exception() && that.valueless_by_exception()) { + // do nothing. + } else if (that.valueless_by_exception()) { + this->destroy(); + } else { + visitation::alt::visit_alt_at( + that.index(), +#ifdef MPARK_GENERIC_LAMBDAS + [this](auto &this_alt, auto &&that_alt) { + this->assign_alt(this_alt, + lib::forward(that_alt).value); + } +#else + assigner { this } +#endif + , + *this, + lib::forward(that)); + } + } +}; + +template +class move_assignment; + +#define MPARK_VARIANT_MOVE_ASSIGNMENT(move_assignable_trait, definition) \ + template \ + class move_assignment, move_assignable_trait> \ + : public assignment> { \ + using super = assignment>; \ + \ + public: \ + MPARK_INHERITING_CTOR(move_assignment, super) \ + using super::operator=; \ + \ + move_assignment(const move_assignment &) = default; \ + move_assignment(move_assignment &&) = default; \ + ~move_assignment() = default; \ + move_assignment &operator=(const move_assignment &) = default; \ + definition \ + } + +MPARK_VARIANT_MOVE_ASSIGNMENT( + Trait::TriviallyAvailable, + move_assignment &operator=(move_assignment &&that) = default;); + +MPARK_VARIANT_MOVE_ASSIGNMENT( + Trait::Available, + move_assignment & + operator=(move_assignment &&that) noexcept( + lib::all<(std::is_nothrow_move_constructible::value && + std::is_nothrow_move_assignable::value)...>::value) { + this->generic_assign(lib::move(that)); + return *this; + }); + +MPARK_VARIANT_MOVE_ASSIGNMENT( + Trait::Unavailable, + move_assignment &operator=(move_assignment &&) = delete;); + +#undef MPARK_VARIANT_MOVE_ASSIGNMENT + +template +class copy_assignment; + +#define MPARK_VARIANT_COPY_ASSIGNMENT(copy_assignable_trait, definition) \ + template \ + class copy_assignment, copy_assignable_trait> \ + : public move_assignment> { \ + using super = move_assignment>; \ + \ + public: \ + MPARK_INHERITING_CTOR(copy_assignment, super) \ + using super::operator=; \ + \ + copy_assignment(const copy_assignment &) = default; \ + copy_assignment(copy_assignment &&) = default; \ + ~copy_assignment() = default; \ + definition copy_assignment &operator=(copy_assignment &&) = default; \ + } + +MPARK_VARIANT_COPY_ASSIGNMENT( + Trait::TriviallyAvailable, + copy_assignment &operator=(const copy_assignment &that) = default;); + +MPARK_VARIANT_COPY_ASSIGNMENT( + Trait::Available, copy_assignment &operator=(const copy_assignment &that) { + this->generic_assign(that); + return *this; + }); + +MPARK_VARIANT_COPY_ASSIGNMENT( + Trait::Unavailable, + copy_assignment &operator=(const copy_assignment &) = delete;); + +#undef MPARK_VARIANT_COPY_ASSIGNMENT + +template +class impl : public copy_assignment> { + using super = copy_assignment>; + + public: + MPARK_INHERITING_CTOR(impl, super) + using super::operator=; + + template + inline void assign(Arg &&arg) { + this->assign_alt(access::base::get_alt(*this), lib::forward(arg)); + } + + inline void swap(impl &that) { + if (this->valueless_by_exception() && that.valueless_by_exception()) { + // do nothing. + } else if (this->index() == that.index()) { + visitation::alt::visit_alt_at(this->index(), +#ifdef MPARK_GENERIC_LAMBDAS + [](auto &this_alt, auto &that_alt) { + using std::swap; + swap(this_alt.value, that_alt.value); + } +#else + swapper {} +#endif + , + *this, + that); + } else { + impl *lhs = this; + impl *rhs = lib::addressof(that); + if (lhs->move_nothrow() && !rhs->move_nothrow()) { + std::swap(lhs, rhs); + } + impl tmp(lib::move(*rhs)); +#ifdef MPARK_EXCEPTIONS + // EXTENSION: When the move construction of `lhs` into `rhs` throws + // and `tmp` is nothrow move constructible then we move `tmp` back + // into `rhs` and provide the strong exception safety guarantee. + try { + this->generic_construct(*rhs, lib::move(*lhs)); + } catch (...) { + if (tmp.move_nothrow()) { + this->generic_construct(*rhs, lib::move(tmp)); + } + throw; + } +#else + this->generic_construct(*rhs, lib::move(*lhs)); +#endif + this->generic_construct(*lhs, lib::move(tmp)); + } + } + + private: +#ifndef MPARK_GENERIC_LAMBDAS + struct swapper { + template + inline void operator()(ThisAlt &this_alt, ThatAlt &that_alt) const { + using std::swap; + swap(this_alt.value, that_alt.value); + } + }; +#endif + + inline constexpr bool move_nothrow() const { + return this->valueless_by_exception() || + lib::array{{std::is_nothrow_move_constructible< + Ts>::value...}}[this->index()]; + } +}; + +#undef MPARK_INHERITING_CTOR + +template +struct overload_leaf { + using F = lib::size_constant (*)(T); + operator F() const { return nullptr; } +}; + +template +struct overload_impl { + private: + template + struct impl; + + template + struct impl> : overload_leaf... {}; + + public: + using type = impl>; +}; + +template +using overload = typename overload_impl::type; + +template +using best_match = lib::invoke_result_t, T &&>; + +template +struct is_in_place_index : std::false_type {}; + +template +struct is_in_place_index> : std::true_type {}; + +template +struct is_in_place_type : std::false_type {}; + +template +struct is_in_place_type> : std::true_type {}; + +} // detail + +template +class variant { + static_assert(0 < sizeof...(Ts), + "variant must consist of at least one alternative."); + + static_assert(lib::all::value...>::value, + "variant can not have an array type as an alternative."); + + static_assert(lib::all::value...>::value, + "variant can not have a reference type as an alternative."); + + static_assert(lib::all::value...>::value, + "variant can not have a void type as an alternative."); + + public: + template < + typename Front = lib::type_pack_element_t<0, Ts...>, + lib::enable_if_t::value, int> = 0> + inline constexpr variant() noexcept( + std::is_nothrow_default_constructible::value) + : impl_(in_place_index_t<0>{}) {} + + variant(const variant &) = default; + variant(variant &&) = default; + + template < + typename Arg, + typename Decayed = lib::decay_t, + lib::enable_if_t::value, int> = 0, + lib::enable_if_t::value, int> = 0, + lib::enable_if_t::value, int> = 0, + std::size_t I = detail::best_match::value, + typename T = lib::type_pack_element_t, + lib::enable_if_t::value, int> = 0> + inline constexpr variant(Arg &&arg) noexcept( + std::is_nothrow_constructible::value) + : impl_(in_place_index_t{}, lib::forward(arg)) {} + + template , + lib::enable_if_t::value, int> = 0> + inline explicit constexpr variant( + in_place_index_t, + Args + &&... args) noexcept(std::is_nothrow_constructible::value) + : impl_(in_place_index_t{}, lib::forward(args)...) {} + + template < + std::size_t I, + typename Up, + typename... Args, + typename T = lib::type_pack_element_t, + lib::enable_if_t< + std::is_constructible &, Args...>::value, + int> = 0> + inline explicit constexpr variant( + in_place_index_t, + std::initializer_list il, + Args &&... args) noexcept(std:: + is_nothrow_constructible< + T, + std::initializer_list &, + Args...>::value) + : impl_(in_place_index_t{}, il, lib::forward(args)...) {} + + template ::value, + lib::enable_if_t::value, int> = 0> + inline explicit constexpr variant( + in_place_type_t, + Args + &&... args) noexcept(std::is_nothrow_constructible::value) + : impl_(in_place_index_t{}, lib::forward(args)...) {} + + template < + typename T, + typename Up, + typename... Args, + std::size_t I = detail::find_index_sfinae::value, + lib::enable_if_t< + std::is_constructible &, Args...>::value, + int> = 0> + inline explicit constexpr variant( + in_place_type_t, + std::initializer_list il, + Args &&... args) noexcept(std:: + is_nothrow_constructible< + T, + std::initializer_list &, + Args...>::value) + : impl_(in_place_index_t{}, il, lib::forward(args)...) {} + + ~variant() = default; + + variant &operator=(const variant &) = default; + variant &operator=(variant &&) = default; + + template , variant>::value, + int> = 0, + std::size_t I = detail::best_match::value, + typename T = lib::type_pack_element_t, + lib::enable_if_t<(std::is_assignable::value && + std::is_constructible::value), + int> = 0> + inline variant &operator=(Arg &&arg) noexcept( + (std::is_nothrow_assignable::value && + std::is_nothrow_constructible::value)) { + impl_.template assign(lib::forward(arg)); + return *this; + } + + template , + lib::enable_if_t::value, int> = 0> + inline T &emplace(Args &&... args) { + return impl_.template emplace(lib::forward(args)...); + } + + template < + std::size_t I, + typename Up, + typename... Args, + typename T = lib::type_pack_element_t, + lib::enable_if_t< + std::is_constructible &, Args...>::value, + int> = 0> + inline T &emplace(std::initializer_list il, Args &&... args) { + return impl_.template emplace(il, lib::forward(args)...); + } + + template ::value, + lib::enable_if_t::value, int> = 0> + inline T &emplace(Args &&... args) { + return impl_.template emplace(lib::forward(args)...); + } + + template < + typename T, + typename Up, + typename... Args, + std::size_t I = detail::find_index_sfinae::value, + lib::enable_if_t< + std::is_constructible &, Args...>::value, + int> = 0> + inline T &emplace(std::initializer_list il, Args &&... args) { + return impl_.template emplace(il, lib::forward(args)...); + } + + inline constexpr bool valueless_by_exception() const noexcept { + return impl_.valueless_by_exception(); + } + + inline constexpr std::size_t index() const noexcept { return impl_.index(); } + + template , + Dummy>::value && + lib::dependent_type, + Dummy>::value)...>::value, + int> = 0> + inline void swap(variant &that) noexcept( + lib::all<(std::is_nothrow_move_constructible::value && + lib::is_nothrow_swappable::value)...>::value) { + impl_.swap(that.impl_); + } + + private: + detail::impl impl_; + + friend struct detail::access::variant; + friend struct detail::visitation::variant; +}; + +template +inline constexpr bool holds_alternative(const variant &v) noexcept { + return v.index() == I; +} + +template +inline constexpr bool holds_alternative(const variant &v) noexcept { + return holds_alternative::value>(v); +} + +namespace detail { +template +struct generic_get_impl { + constexpr generic_get_impl(int) noexcept {} + + constexpr AUTO_REFREF operator()(V &&v) const + AUTO_REFREF_RETURN(access::variant::get_alt(lib::forward(v)).value) +}; + +template +inline constexpr AUTO_REFREF generic_get(V &&v) + AUTO_REFREF_RETURN(generic_get_impl(holds_alternative(v) + ? 0 + : (throw_bad_variant_access(), + 0))(lib::forward(v))) +} // namespace detail + +template +inline constexpr variant_alternative_t> &get( + variant &v) { + return detail::generic_get(v); +} + +template +inline constexpr variant_alternative_t> &&get( + variant &&v) { + return detail::generic_get(lib::move(v)); +} + +template +inline constexpr const variant_alternative_t> &get( + const variant &v) { + return detail::generic_get(v); +} + +template +inline constexpr const variant_alternative_t> &&get( + const variant &&v) { + return detail::generic_get(lib::move(v)); +} + +template +inline constexpr T &get(variant &v) { + return get::value>(v); +} + +template +inline constexpr T &&get(variant &&v) { + return get::value>(lib::move(v)); +} + +template +inline constexpr const T &get(const variant &v) { + return get::value>(v); +} + +template +inline constexpr const T &&get(const variant &&v) { + return get::value>(lib::move(v)); +} + +namespace detail { + +template +inline constexpr /* auto * */ AUTO generic_get_if(V *v) noexcept AUTO_RETURN( + v &&holds_alternative(*v) + ? lib::addressof(access::variant::get_alt(*v).value) + : nullptr) + +} // namespace detail + +template +inline constexpr lib::add_pointer_t>> +get_if(variant *v) noexcept { + return detail::generic_get_if(v); +} + +template +inline constexpr lib::add_pointer_t< + const variant_alternative_t>> +get_if(const variant *v) noexcept { + return detail::generic_get_if(v); +} + +template +inline constexpr lib::add_pointer_t get_if(variant *v) noexcept { + return get_if::value>(v); +} + +template +inline constexpr lib::add_pointer_t get_if( + const variant *v) noexcept { + return get_if::value>(v); +} + +namespace detail { +template +struct convert_to_bool { + template + inline constexpr bool operator()(Lhs &&lhs, Rhs &&rhs) const { + static_assert( + std::is_convertible, bool>::value, + "relational operators must return a type" + " implicitly convertible to bool"); + return lib::invoke(RelOp{}, lib::forward(lhs), lib::forward(rhs)); + } +}; +} // namespace detail + +template +inline constexpr bool operator==(const variant &lhs, + const variant &rhs) { + using detail::visitation::variant; + using equal_to = detail::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (lhs.index() != rhs.index()) return false; + if (lhs.valueless_by_exception()) return true; + return variant::visit_value_at(lhs.index(), equal_to{}, lhs, rhs); +#else + return lhs.index() == rhs.index() && + (lhs.valueless_by_exception() || + variant::visit_value_at(lhs.index(), equal_to{}, lhs, rhs)); +#endif +} + +template +inline constexpr bool operator!=(const variant &lhs, + const variant &rhs) { + using detail::visitation::variant; + using not_equal_to = detail::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (lhs.index() != rhs.index()) return true; + if (lhs.valueless_by_exception()) return false; + return variant::visit_value_at(lhs.index(), not_equal_to{}, lhs, rhs); +#else + return lhs.index() != rhs.index() || + (!lhs.valueless_by_exception() && + variant::visit_value_at(lhs.index(), not_equal_to{}, lhs, rhs)); +#endif +} + +template +inline constexpr bool operator<(const variant &lhs, + const variant &rhs) { + using detail::visitation::variant; + using less = detail::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (rhs.valueless_by_exception()) return false; + if (lhs.valueless_by_exception()) return true; + if (lhs.index() < rhs.index()) return true; + if (lhs.index() > rhs.index()) return false; + return variant::visit_value_at(lhs.index(), less{}, lhs, rhs); +#else + return !rhs.valueless_by_exception() && + (lhs.valueless_by_exception() || lhs.index() < rhs.index() || + (lhs.index() == rhs.index() && + variant::visit_value_at(lhs.index(), less{}, lhs, rhs))); +#endif +} + +template +inline constexpr bool operator>(const variant &lhs, + const variant &rhs) { + using detail::visitation::variant; + using greater = detail::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (lhs.valueless_by_exception()) return false; + if (rhs.valueless_by_exception()) return true; + if (lhs.index() > rhs.index()) return true; + if (lhs.index() < rhs.index()) return false; + return variant::visit_value_at(lhs.index(), greater{}, lhs, rhs); +#else + return !lhs.valueless_by_exception() && + (rhs.valueless_by_exception() || lhs.index() > rhs.index() || + (lhs.index() == rhs.index() && + variant::visit_value_at(lhs.index(), greater{}, lhs, rhs))); +#endif +} + +template +inline constexpr bool operator<=(const variant &lhs, + const variant &rhs) { + using detail::visitation::variant; + using less_equal = detail::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (lhs.valueless_by_exception()) return true; + if (rhs.valueless_by_exception()) return false; + if (lhs.index() < rhs.index()) return true; + if (lhs.index() > rhs.index()) return false; + return variant::visit_value_at(lhs.index(), less_equal{}, lhs, rhs); +#else + return lhs.valueless_by_exception() || + (!rhs.valueless_by_exception() && + (lhs.index() < rhs.index() || + (lhs.index() == rhs.index() && + variant::visit_value_at(lhs.index(), less_equal{}, lhs, rhs)))); +#endif +} + +template +inline constexpr bool operator>=(const variant &lhs, + const variant &rhs) { + using detail::visitation::variant; + using greater_equal = detail::convert_to_bool; +#ifdef MPARK_CPP14_CONSTEXPR + if (rhs.valueless_by_exception()) return true; + if (lhs.valueless_by_exception()) return false; + if (lhs.index() > rhs.index()) return true; + if (lhs.index() < rhs.index()) return false; + return variant::visit_value_at(lhs.index(), greater_equal{}, lhs, rhs); +#else + return rhs.valueless_by_exception() || + (!lhs.valueless_by_exception() && + (lhs.index() > rhs.index() || + (lhs.index() == rhs.index() && + variant::visit_value_at(lhs.index(), greater_equal{}, lhs, rhs)))); +#endif +} + +struct monostate {}; + +inline constexpr bool operator<(monostate, monostate) noexcept { return false; } + +inline constexpr bool operator>(monostate, monostate) noexcept { return false; } + +inline constexpr bool operator<=(monostate, monostate) noexcept { return true; } + +inline constexpr bool operator>=(monostate, monostate) noexcept { return true; } + +inline constexpr bool operator==(monostate, monostate) noexcept { return true; } + +inline constexpr bool operator!=(monostate, monostate) noexcept { + return false; +} + +#ifdef MPARK_CPP14_CONSTEXPR +namespace detail { + +inline constexpr bool all(std::initializer_list bs) { + for (bool b : bs) { + if (!b) { + return false; + } + } + return true; +} + +} // namespace detail + +template +inline constexpr decltype(auto) visit(Visitor &&visitor, Vs &&... vs) { + return (detail::all( + lib::array{!vs.valueless_by_exception()...}) + ? (void)0 + : throw_bad_variant_access()), + detail::visitation::variant::visit_value( + lib::forward(visitor), lib::forward(vs)...); +} +#else +namespace detail { + +template +inline constexpr bool all_impl(const lib::array &bs, std::size_t idx) { + return idx >= N || (bs[idx] && all_impl(bs, idx + 1)); +} + +template +inline constexpr bool all(const lib::array &bs) { + return all_impl(bs, 0); +} + +} // namespace detail + +template +inline constexpr DECLTYPE_AUTO visit(Visitor &&visitor, Vs &&... vs) + DECLTYPE_AUTO_RETURN( + (detail::all(lib::array{ + {!vs.valueless_by_exception()...}}) + ? (void)0 + : throw_bad_variant_access()), + detail::visitation::variant::visit_value(lib::forward(visitor), + lib::forward(vs)...)) +#endif + +template +inline auto swap(variant &lhs, + variant &rhs) noexcept(noexcept(lhs.swap(rhs))) + -> decltype(lhs.swap(rhs)) { + lhs.swap(rhs); +} + +namespace detail { + +template +using enabled_type = T; + +namespace hash { + +template +constexpr bool meets_requirements() noexcept { + return std::is_copy_constructible::value && + std::is_move_constructible::value && + lib::is_invocable_r::value; +} + +template +constexpr bool is_enabled() noexcept { + using H = std::hash; + return meets_requirements() && + std::is_default_constructible::value && + std::is_copy_assignable::value && std::is_move_assignable::value; +} + +} // namespace hash + +} // namespace detail + +#undef AUTO +#undef AUTO_RETURN + +#undef AUTO_REFREF +#undef AUTO_REFREF_RETURN + +#undef DECLTYPE_AUTO +#undef DECLTYPE_AUTO_RETURN + +} // namespace paddle + +namespace std { + +template +struct hash, + paddle::lib::enable_if_t>()...>::value>>> { + using argument_type = paddle::variant; + using result_type = std::size_t; + + inline result_type operator()(const argument_type &v) const { + using paddle::detail::visitation::variant; + std::size_t result = + v.valueless_by_exception() + ? 299792458 // Random value chosen by the universe upon creation + : variant::visit_alt( +#ifdef MPARK_GENERIC_LAMBDAS + [](const auto &alt) { + using alt_type = paddle::lib::decay_t; + using value_type = paddle::lib::remove_const_t< + typename alt_type::value_type>; + return hash{}(alt.value); + } +#else + hasher {} +#endif + , + v); + return hash_combine(result, hash{}(v.index())); + } + + private: +#ifndef MPARK_GENERIC_LAMBDAS + struct hasher { + template + inline std::size_t operator()(const Alt &alt) const { + using alt_type = paddle::lib::decay_t; + using value_type = + paddle::lib::remove_const_t; + return hash{}(alt.value); + } + }; +#endif + + static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) { + return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + } +}; + +template <> +struct hash { + using argument_type = paddle::monostate; + using result_type = std::size_t; + + inline result_type operator()(const argument_type &) const noexcept { + return 66740831; // return a fundamentally attractive random value. + } +}; + +} // namespace std -- GitLab