// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/platform/place.h" #ifdef PADDLE_WITH_CUDA #include #ifndef _WIN32 #include #endif #endif // Users should add forward declarations here namespace paddle { namespace platform { #ifdef PADDLE_WITH_CUDA #ifndef _WIN32 class Communicator; #endif #endif } // namespace platform namespace framework { class Tensor; class LoDTensor; class SelectedRows; class LoDRankTable; class ReaderHolder; class Scope; } // namespace framework namespace operators { template class AlgorithmsCache; class CudnnRNNCache; namespace reader { class LoDTensorBlockingQueueHolder; } // namespace reader } // namespace operators } // namespace paddle namespace paddle { namespace framework { const char *ToTypeName(int var_id); const std::type_index &VarTraitIdToTypeIndex(int var_id); int TypeIndexToVarTraitId(const std::type_index &type); namespace detail { template struct TypePosFinderImpl { static constexpr int kPos = std::is_same::value ? kStart : TypePosFinderImpl::kPos; }; template struct TypePosFinderImpl { static constexpr int kPos = std::is_same::value ? kStart : -1; }; // TypePosFinder helps to find the position in which T is inside Args... // If T is not inside Args..., kPos would be -1 template struct TypePosFinder { static constexpr int kPos = TypePosFinderImpl::kPos; }; template struct VarTypeRegistryImpl { static constexpr size_t kRegisteredTypeNum = sizeof...(Args); using ArgTuple = std::tuple; // TypePos() returns the position in which T is inside Args... // If T is not inside Args..., return -1 template static constexpr int TypePos() { return TypePosFinder::kPos; } // IsRegistered() returns whether T is registered inside RegistryImpl template static constexpr bool IsRegistered() { return TypePos() >= 0; } }; } // namespace detail #define REG_PROTO_VAR_TYPE_TRAIT(type, proto_id) \ template <> \ struct VarTypeTrait { \ static_assert(VarTypeRegistry::IsRegistered(), \ "Must be registered type"); \ using Type = type; \ static constexpr int kId = static_cast(proto_id); \ } /** * The following codes are designed to register variable types. * Only registered types can be stored in Variable. * This registry mechanism is designed to speed up Variable. * * Caution: If you want to add more var types, please consider carefully * whether you really need to add it. */ // Users should add other variable types below. // Paddle would generate unique Ids for each registered variable types. using VarTypeRegistry = detail::VarTypeRegistryImpl< Tensor, LoDTensor, SelectedRows, std::vector, LoDRankTable, LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *, std::map, operators::reader::LoDTensorBlockingQueueHolder, #ifdef PADDLE_WITH_CUDA #ifndef _WIN32 ncclUniqueId, platform::Communicator, #endif operators::AlgorithmsCache, operators::AlgorithmsCache, operators::AlgorithmsCache, operators::CudnnRNNCache, #endif int, float>; template struct VarTypeTrait { static_assert(VarTypeRegistry::IsRegistered(), "Must be registered type"); using Type = T; /** * Unique VarType Id generation. * * The auto-generated id should not be the same as any protobuf id defined in * framework.proto. Therefore, we generate id by adding the type pos and * maximum protobuf id (i.e., proto::VarType::TUPLE). * * However, we may need more protobuf id in the future. * To avoid changing this auto id generation algorithm frequently, we * generate id by adding the type pos and twice of maximum protobuf id (i.e., * proto::VarType::TUPLE). */ static constexpr int kId = VarTypeRegistry::TypePos() + static_cast(proto::VarType::TUPLE) * 2; }; // Users should set some of variable type ids to be what is defined in // framework.proto below REG_PROTO_VAR_TYPE_TRAIT(LoDTensor, proto::VarType::LOD_TENSOR); REG_PROTO_VAR_TYPE_TRAIT(SelectedRows, proto::VarType::SELECTED_ROWS); REG_PROTO_VAR_TYPE_TRAIT(std::vector, proto::VarType::STEP_SCOPES); REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE); REG_PROTO_VAR_TYPE_TRAIT(LoDTensorArray, proto::VarType::LOD_TENSOR_ARRAY); REG_PROTO_VAR_TYPE_TRAIT(platform::PlaceList, proto::VarType::PLACE_LIST); REG_PROTO_VAR_TYPE_TRAIT(ReaderHolder, proto::VarType::READER); REG_PROTO_VAR_TYPE_TRAIT(int, proto::VarType::INT32); REG_PROTO_VAR_TYPE_TRAIT(float, proto::VarType::FP32); /** End of variable type registration */ template inline constexpr bool IsRegisteredVarType() { return VarTypeRegistry::IsRegistered(); } #undef REG_PROTO_VAR_TYPE_TRAIT } // namespace framework } // namespace paddle