var_type_traits_test.cc 5.2 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include <cstdint>
S
sneaxiy 已提交
17
#include <iostream>
S
sneaxiy 已提交
18
#include <unordered_set>
S
sneaxiy 已提交
19

S
sneaxiy 已提交
20 21 22 23 24 25 26 27 28 29 30 31 32 33
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#endif
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#endif

S
sneaxiy 已提交
34 35 36 37 38
namespace paddle {
namespace framework {

template <int kPos, int kEnd, bool kStop>
struct TypeIndexChecker {
S
sneaxiy 已提交
39 40
  template <typename SetType1, typename SetType2>
  static void Check(SetType1 *var_id_set, SetType2 *type_index_set) {
S
sneaxiy 已提交
41 42
    using Type =
        typename std::tuple_element<kPos, VarTypeRegistry::ArgTuple>::type;
S
sneaxiy 已提交
43 44 45
    static_assert(std::is_same<typename VarTypeTrait<Type>::Type, Type>::value,
                  "Type must be the same");
    constexpr auto kId = VarTypeTrait<Type>::kId;
S
sneaxiy 已提交
46 47
    std::type_index actual_type(typeid(Type));
    EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name()));
S
sneaxiy 已提交
48 49 50 51 52
    EXPECT_EQ(VarTraitIdToTypeIndex(kId), actual_type);
    EXPECT_EQ(TypeIndexToVarTraitId(actual_type), kId);
    EXPECT_EQ(VarTraitIdToTypeIndex(TypeIndexToVarTraitId(actual_type)),
              actual_type);
    EXPECT_EQ(TypeIndexToVarTraitId(VarTraitIdToTypeIndex(kId)), kId);
S
sneaxiy 已提交
53 54 55 56 57

    EXPECT_TRUE(var_id_set->count(kId) == 0);              // NOLINT
    EXPECT_TRUE(type_index_set->count(actual_type) == 0);  // NOLINT
    var_id_set->insert(kId);
    type_index_set->insert(std::type_index(typeid(Type)));
S
sneaxiy 已提交
58 59
    TypeIndexChecker<kPos + 1, kEnd, kPos + 1 == kEnd>::Check(var_id_set,
                                                              type_index_set);
S
sneaxiy 已提交
60 61 62 63 64
  }
};

// Terminating case of TypeIndexChecker, chosen when kStop is true (the walker
// requests it once the next position equals kEnd, and the test requests it
// directly when no types are registered). There is nothing left to check.
template <int kPos, int kEnd>
struct TypeIndexChecker<kPos, kEnd, true> {
  template <typename IdSet, typename TypeSet>
  static void Check(IdSet * /*var_id_set*/, TypeSet * /*type_index_set*/) {}
};

S
sneaxiy 已提交
69
// Walks every type registered in VarTypeRegistry and checks that each one has
// a distinct var type id and a distinct std::type_index.
TEST(var_type_traits, check_no_duplicate_registry) {
  constexpr size_t kRegisteredNum = VarTypeRegistry::kRegisteredTypeNum;
  std::unordered_set<int> var_id_set;
  std::unordered_set<std::type_index> type_index_set;
  TypeIndexChecker<0, kRegisteredNum, kRegisteredNum == 0>::Check(
      &var_id_set, &type_index_set);

  // Each visited type inserts into both sets, so with no collisions both sets
  // end up holding exactly one entry per registered type.
  EXPECT_EQ(var_id_set.size(), kRegisteredNum);
  EXPECT_EQ(type_index_set.size(), kRegisteredNum);
}

template <typename T>
bool CheckVarId(int proto_id) {
  static_assert(std::is_same<typename VarTypeTrait<T>::Type, T>::value,
                "Type must be the same");
  return VarTypeTrait<T>::kId == proto_id;
}

// Checks that each listed C++ type is registered with the matching
// proto::VarType id, and that the flat enum constants (proto::VarType_Type_*)
// agree with their scoped spellings (proto::VarType::*).
//
// The checks are independent of each other, so EXPECT_* is used: a mismatch
// is recorded but the remaining checks still run, so one test run reports
// every mismatch instead of only the first.
TEST(var_type_traits, check_proto_type_id) {
  EXPECT_TRUE(CheckVarId<LoDTensor>(proto::VarType::LOD_TENSOR));
  EXPECT_TRUE(CheckVarId<SelectedRows>(proto::VarType::SELECTED_ROWS));
  EXPECT_TRUE(CheckVarId<std::vector<Scope *>>(proto::VarType::STEP_SCOPES));
  EXPECT_TRUE(CheckVarId<LoDRankTable>(proto::VarType::LOD_RANK_TABLE));
  EXPECT_TRUE(CheckVarId<LoDTensorArray>(proto::VarType::LOD_TENSOR_ARRAY));
  EXPECT_TRUE(CheckVarId<platform::PlaceList>(proto::VarType::PLACE_LIST));
  EXPECT_TRUE(CheckVarId<ReaderHolder>(proto::VarType::READER));
  EXPECT_TRUE(CheckVarId<int>(proto::VarType::INT32));
  EXPECT_TRUE(CheckVarId<float>(proto::VarType::FP32));

  EXPECT_EQ(proto::VarType_Type_LOD_TENSOR, proto::VarType::LOD_TENSOR);
  EXPECT_EQ(proto::VarType_Type_SELECTED_ROWS, proto::VarType::SELECTED_ROWS);
  EXPECT_EQ(proto::VarType_Type_STEP_SCOPES, proto::VarType::STEP_SCOPES);
  EXPECT_EQ(proto::VarType_Type_LOD_RANK_TABLE, proto::VarType::LOD_RANK_TABLE);
  EXPECT_EQ(proto::VarType_Type_LOD_TENSOR_ARRAY,
            proto::VarType::LOD_TENSOR_ARRAY);
  EXPECT_EQ(proto::VarType_Type_PLACE_LIST, proto::VarType::PLACE_LIST);
  EXPECT_EQ(proto::VarType_Type_READER, proto::VarType::READER);
  EXPECT_EQ(proto::VarType_Type_FEED_MINIBATCH, proto::VarType::FEED_MINIBATCH);
  EXPECT_EQ(proto::VarType_Type_FETCH_LIST, proto::VarType::FETCH_LIST);
  EXPECT_EQ(proto::VarType_Type_RAW, proto::VarType::RAW);
  EXPECT_EQ(proto::VarType_Type_TUPLE, proto::VarType::TUPLE);
  EXPECT_EQ(proto::VarType_Type_INT32, proto::VarType::INT32);
  EXPECT_EQ(proto::VarType_Type_FP32, proto::VarType::FP32);
}

// Checks VarTypeRegistryImpl::TypePos on a small standalone registry: each
// listed type reports its zero-based position in the template argument list,
// and a type that is not listed reports -1.
TEST(var_type_traits, test_registry) {
  using Registry = detail::VarTypeRegistryImpl<int8_t, int32_t, size_t, double>;
  // EXPECT_EQ prints the actual position on mismatch, which
  // EXPECT_TRUE(a == b) cannot; the checks are independent, so all run.
  EXPECT_EQ(Registry::TypePos<int8_t>(), 0);
  EXPECT_EQ(Registry::TypePos<int32_t>(), 1);
  EXPECT_EQ(Registry::TypePos<size_t>(), 2);
  EXPECT_EQ(Registry::TypePos<double>(), 3);
  EXPECT_EQ(Registry::TypePos<float>(), -1);
}

}  // namespace framework
}  // namespace paddle