/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/attribute.h"

#include <vector>

namespace paddle {
namespace framework {

// Returns the process-wide ProgramDesc, creating it on first use.
//
// The instance lives behind a function-local static, so its creation happens
// exactly once even if several threads make the first call concurrently
// (C++11 thread-safe static initialization). A check-then-assign on a plain
// global pointer is a data race in that situation. Only the creation is made
// thread-safe here; callers must still synchronize access to the returned
// object themselves.
//
// The object is deliberately never deleted, so it is never destroyed during
// static destruction.
ProgramDesc& GetProgramDesc() {
  static ProgramDesc* const program_desc = new ProgramDesc();
  return *program_desc;
}

// Explicit specializations of AttrTypeID<T>(): each one maps a C++ attribute
// value type to its AttrType enumerator.

// Scalar attribute types.
template <>
AttrType AttrTypeID<int>() { return AttrType::INT; }
template <>
AttrType AttrTypeID<float>() { return AttrType::FLOAT; }
template <>
AttrType AttrTypeID<std::string>() { return AttrType::STRING; }

// List attribute types.
template <>
AttrType AttrTypeID<std::vector<int>>() { return AttrType::INTS; }
template <>
AttrType AttrTypeID<std::vector<float>>() { return AttrType::FLOATS; }
template <>
AttrType AttrTypeID<std::vector<std::string>>() { return AttrType::STRINGS; }
template <>
AttrType AttrTypeID<std::vector<std::pair<int, int>>>() {
  return AttrType::INT_PAIRS;
}

// Sub-block attribute type.
template <>
AttrType AttrTypeID<BlockDesc>() { return AttrType::BLOCK; }

// Converts a serialized attribute (OpDesc::Attr) into an Attribute value.
//
// attr_desc.type() selects which payload field is read:
//   INT / FLOAT / STRING     -> i() / f() / s()
//   INTS / FLOATS / STRINGS  -> ints() / floats() / strings(), as std::vector
//   INT_PAIRS                -> int_pairs(), as a vector of (first, second)
//   BLOCK                    -> pointer to block number block_idx() of the
//                               ProgramDesc returned by GetProgramDesc()
//
// Fails via PADDLE_ENFORCE if a BLOCK index is outside that program's blocks,
// or if type() is none of the values handled below.
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
  switch (attr_desc.type()) {
    case framework::AttrType::INT: {
      return attr_desc.i();
    }
    case framework::AttrType::FLOAT: {
      return attr_desc.f();
    }
    case framework::AttrType::STRING: {
      return attr_desc.s();
    }
    case framework::AttrType::INTS: {
      const auto& ints = attr_desc.ints();
      return std::vector<int>(ints.begin(), ints.end());
    }
    case framework::AttrType::FLOATS: {
      const auto& floats = attr_desc.floats();
      return std::vector<float>(floats.begin(), floats.end());
    }
    case framework::AttrType::STRINGS: {
      const auto& strings = attr_desc.strings();
      return std::vector<std::string>(strings.begin(), strings.end());
    }
    case framework::AttrType::INT_PAIRS: {
      std::vector<std::pair<int, int>> val;
      val.reserve(attr_desc.int_pairs_size());
      for (const auto& pair : attr_desc.int_pairs()) {
        val.emplace_back(pair.first(), pair.second());
      }
      return val;
    }
    case framework::AttrType::BLOCK: {
      ProgramDesc& program = GetProgramDesc();
      const int idx = attr_desc.block_idx();
      // Protobuf's mutable_blocks(i) requires 0 <= i < blocks_size() and does
      // not range-check in release builds. Reject a bad index from a malformed
      // OpDesc explicitly instead of invoking undefined behavior.
      PADDLE_ENFORCE(idx >= 0 && idx < program.blocks_size(),
                     "Block index of BLOCK attribute is out of range");
      return program.mutable_blocks(idx);
    }
  }
  PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
  return boost::blank();
}

}  // namespace framework
}  // namespace paddle