/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
Y
Yi Wang 已提交
16 17 18 19 20
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/selected_rows.h"
S
sneaxiy 已提交
21
#include "paddle/fluid/framework/var_type_traits.h"
Y
Yi Wang 已提交
22
#include "paddle/fluid/framework/variable.h"
23 24 25

namespace paddle {
namespace framework {
S
sneaxiy 已提交
26 27

// Reports whether `type` names exactly the C++ type T.
//
// The comparison is an exact identity check on the runtime type
// information: no base/derived relationship is considered.
template <typename T>
inline bool IsType(const std::type_index& type) {
  const std::type_index expected(typeid(T));
  return expected == type;
}

S
sneaxiy 已提交
32 33 34 35 36 37
// Converts a raw integer type id into proto::VarType::Type.
//
// Only LOD_TENSOR, SELECTED_ROWS, LOD_RANK_TABLE, LOD_TENSOR_ARRAY,
// FETCH_LIST and READER are accepted. Any other value raises a
// platform::errors::Unavailable error through PADDLE_THROW.
inline proto::VarType::Type ToVarType(int type) {
  switch (type) {
    case proto::VarType::LOD_TENSOR:
    case proto::VarType::SELECTED_ROWS:
    case proto::VarType::LOD_RANK_TABLE:
    case proto::VarType::LOD_TENSOR_ARRAY:
    case proto::VarType::FETCH_LIST:
    case proto::VarType::READER:
      // Supported id: leave the switch and convert below.
      break;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "ToVarType method Unsupported type %d.", type));
  }
  return static_cast<proto::VarType::Type>(type);
}

Y
Yu Yang 已提交
47
template <typename Visitor>
Y
Yancey 已提交
48
inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
S
sneaxiy 已提交
49
  switch (var.Type()) {
S
sneaxiy 已提交
50
    case proto::VarType::LOD_TENSOR:
F
fengjiayi 已提交
51
      visitor(var.Get<LoDTensor>());
Y
Yu Yang 已提交
52
      return;
S
sneaxiy 已提交
53
    case proto::VarType::LOD_RANK_TABLE:
Y
Yu Yang 已提交
54 55
      visitor(var.Get<LoDRankTable>());
      return;
S
sneaxiy 已提交
56
    case proto::VarType::LOD_TENSOR_ARRAY:
Y
Yu Yang 已提交
57 58
      visitor(var.Get<LoDTensorArray>());
      return;
S
sneaxiy 已提交
59
    case proto::VarType::SELECTED_ROWS:
Y
Yu Yang 已提交
60 61
      visitor(var.Get<SelectedRows>());
      return;
S
sneaxiy 已提交
62
    case proto::VarType::READER:
F
fengjiayi 已提交
63 64
      visitor(var.Get<ReaderHolder>());
      return;
65 66 67
    case proto::VarType::FETCH_LIST:
      visitor(var.Get<FetchList>());
      return;
Y
Yu Yang 已提交
68
    default:
69 70
      PADDLE_THROW(platform::errors::Unavailable("Not supported visit type %s.",
                                                 ToTypeName(var.Type())));
Y
Yu Yang 已提交
71 72 73
  }
}

74 75
}  // namespace framework
}  // namespace paddle