/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16
#include <string>
Y
Yu Yang 已提交
17
#include <typeindex>
Y
Yi Wang 已提交
18 19
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
20
#include "paddle/fluid/platform/float16.h"
Y
Yu Yang 已提交
21 22 23 24

namespace paddle {
namespace framework {

25
// Maps a C++ element type, given as a std::type_index, to the corresponding
// proto::VarType::Type enum value.
//
// Supported element types: platform::float16, float, double, int, int64_t,
// bool. Any other type raises an error via PADDLE_THROW("Not supported").
//
// Types are compared as std::type_index objects, which is an exact identity
// test. Comparing std::type_info::hash_code() values is not exact: the
// standard allows distinct types to share a hash code.
//
// The `const T` spellings below are deliberate. typeid ignores top-level
// cv-qualifiers, so typeid(const T) == typeid(T). Writing `typeid(float)`
// makes CPPLint report a false "Using C-style cast" warning.
// http://en.cppreference.com/w/cpp/language/typeid
inline proto::VarType::Type ToDataType(std::type_index type) {
  if (type == std::type_index(typeid(platform::float16))) {
    return proto::VarType::FP16;
  } else if (type == std::type_index(typeid(const float))) {
    return proto::VarType::FP32;
  } else if (type == std::type_index(typeid(const double))) {
    return proto::VarType::FP64;
  } else if (type == std::type_index(typeid(const int))) {
    return proto::VarType::INT32;
  } else if (type == std::type_index(typeid(const int64_t))) {
    return proto::VarType::INT64;
  } else if (type == std::type_index(typeid(const bool))) {
    return proto::VarType::BOOL;
  } else {
    PADDLE_THROW("Not supported");
  }
}

47
// Returns the std::type_index of the C++ element type used for the given
// proto::VarType::Type.
//
// Handles FP16, FP32, FP64, INT32, INT64 and BOOL. Any other value raises an
// error via PADDLE_THROW, and the message includes the numeric enum value.
inline std::type_index ToTypeIndex(proto::VarType::Type type) {
  if (type == proto::VarType::FP16) return typeid(platform::float16);
  if (type == proto::VarType::FP32) return typeid(float);
  if (type == proto::VarType::FP64) return typeid(double);
  if (type == proto::VarType::INT32) return typeid(int);
  if (type == proto::VarType::INT64) return typeid(int64_t);
  if (type == proto::VarType::BOOL) return typeid(bool);
  // None of the supported data types matched.
  PADDLE_THROW("Not support type %d", type);
}

Y
Yu Yang 已提交
66
template <typename Visitor>
67
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
Y
Yu Yang 已提交
68
  switch (type) {
69 70 71
    case proto::VarType::FP16:
      visitor.template operator()<platform::float16>();
      break;
72
    case proto::VarType::FP32:
Y
Yu Yang 已提交
73 74
      visitor.template operator()<float>();
      break;
75
    case proto::VarType::FP64:
Y
Yu Yang 已提交
76 77
      visitor.template operator()<double>();
      break;
78
    case proto::VarType::INT32:
Y
Yu Yang 已提交
79 80
      visitor.template operator()<int>();
      break;
81
    case proto::VarType::INT64:
Y
Yu Yang 已提交
82 83
      visitor.template operator()<int64_t>();
      break;
84
    case proto::VarType::BOOL:
85 86
      visitor.template operator()<bool>();
      break;
Y
Yu Yang 已提交
87 88 89 90 91
    default:
      PADDLE_THROW("Not supported");
  }
}

92
// Returns a human-readable name for a proto::VarType::Type, such as "float32"
// for FP32 or "int64" for INT64.
//
// Handles FP16, FP32, FP64, INT16, INT32, INT64 and BOOL. Any other value
// raises an error via PADDLE_THROW, and the message includes the numeric enum
// value.
inline std::string DataTypeToString(const proto::VarType::Type type) {
  const char* type_name = nullptr;
  switch (type) {
    case proto::VarType::FP16:  type_name = "float16"; break;
    case proto::VarType::FP32:  type_name = "float32"; break;
    case proto::VarType::FP64:  type_name = "float64"; break;
    case proto::VarType::INT16: type_name = "int16";   break;
    case proto::VarType::INT32: type_name = "int32";   break;
    case proto::VarType::INT64: type_name = "int64";   break;
    case proto::VarType::BOOL:  type_name = "bool";    break;
    default:
      PADDLE_THROW("Not support type %d", type);
  }
  // Execution reaches this point only after a case above has set type_name.
  return std::string(type_name);
}

inline std::ostream& operator<<(std::ostream& out,
114
                                const proto::VarType::Type& type) {
C
chengduoZH 已提交
115 116 117 118
  out << DataTypeToString(type);
  return out;
}

Y
Yu Yang 已提交
119 120
}  // namespace framework
}  // namespace paddle