attribute.cc 3.4 KB
Newer Older
Y
Yi Wang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/attribute.h"

#include <vector>

namespace paddle {
namespace framework {

F
fengjiayi 已提交
22 23 24 25 26 27 28 29 30
// Returns the process-wide ProgramDesc, creating it on first use.
//
// The pointer lives in a function-local static: since C++11 its initializer
// runs exactly once even if several threads make the first call at the same
// time, so every caller observes the same instance and no duplicate is
// allocated. The object is intentionally never deleted, so references
// returned here remain valid for the rest of the process, including during
// static destruction at exit.
ProgramDesc& GetProgramDesc() {
  static ProgramDesc* const program_desc = new ProgramDesc();
  return *program_desc;
}

D
dangqingqing 已提交
31 32
// Explicit specializations of AttrTypeID<T>(): each maps a C++ value type to
// the AttrType tag under which GetAttrValue produces a value of that type.
template <>
AttrType AttrTypeID<bool>() {
  return BOOLEAN;
}
template <>
AttrType AttrTypeID<int>() {
  return INT;
}
template <>
AttrType AttrTypeID<float>() {
  return FLOAT;
}
template <>
AttrType AttrTypeID<std::string>() {
  return STRING;
}
template <>
AttrType AttrTypeID<std::vector<bool>>() {
  return BOOLEANS;
}
template <>
AttrType AttrTypeID<std::vector<int>>() {
  return INTS;
}
template <>
AttrType AttrTypeID<std::vector<float>>() {
  return FLOATS;
}
template <>
AttrType AttrTypeID<std::vector<std::string>>() {
  return STRINGS;
}
template <>
AttrType AttrTypeID<std::vector<std::pair<int, int>>>() {
  return INT_PAIRS;
}
template <>
AttrType AttrTypeID<BlockDesc>() {
  return BLOCK;
}
Y
Yi Wang 已提交
71

Y
Yu Yang 已提交
72
// Converts a serialized attribute description into an in-memory Attribute.
//
// Dispatches on attr_desc.type():
//   - scalar tags (BOOLEAN, INT, FLOAT, STRING) return the matching scalar
//     field directly;
//   - list tags (BOOLEANS, INTS, FLOATS, STRINGS, INT_PAIRS) copy the
//     repeated field into a std::vector of the matching element type;
//   - BLOCK returns the result of mutable_blocks(block_idx) on the global
//     ProgramDesc obtained from GetProgramDesc().
// Any other tag falls out of the switch and triggers
// PADDLE_ENFORCE(false, ...) (presumably throws — confirm in enforce.h);
// the trailing boost::blank() only keeps the compiler's return-path
// analysis satisfied.
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
  switch (attr_desc.type()) {
    case framework::AttrType::BOOLEAN: {
      return attr_desc.b();
    }
    case framework::AttrType::INT: {
      return attr_desc.i();
    }
    case framework::AttrType::FLOAT: {
      return attr_desc.f();
    }
    case framework::AttrType::STRING: {
      return attr_desc.s();
    }
    case framework::AttrType::BOOLEANS: {
      const auto& src = attr_desc.bools();
      std::vector<bool> val(src.begin(), src.end());
      return val;
    }
    case framework::AttrType::INTS: {
      const auto& src = attr_desc.ints();
      std::vector<int> val(src.begin(), src.end());
      return val;
    }
    case framework::AttrType::FLOATS: {
      const auto& src = attr_desc.floats();
      std::vector<float> val(src.begin(), src.end());
      return val;
    }
    case framework::AttrType::STRINGS: {
      const auto& src = attr_desc.strings();
      std::vector<std::string> val(src.begin(), src.end());
      return val;
    }
    case framework::AttrType::INT_PAIRS: {
      std::vector<std::pair<int, int>> val;
      val.reserve(attr_desc.int_pairs_size());
      for (const auto& pair : attr_desc.int_pairs()) {
        val.emplace_back(pair.first(), pair.second());
      }
      return val;
    }
    case framework::AttrType::BLOCK: {
      return GetProgramDesc().mutable_blocks(attr_desc.block_idx());
    }
  }
  PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
  return boost::blank();
}

}  // namespace framework
}  // namespace paddle