// op_desc.cc
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/model_parser/flatbuffers/op_desc.h"

namespace paddle {
namespace lite {
namespace fbs {

template <>
22
std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
23 24 25 26 27 28 29 30
  const auto& it = desc_->attrs()->LookupByKey(name.c_str());
  if (!it->s()) {
    return std::string();
  }
  return it->s()->str();
}

template <>
31
std::string OpDescView::GetAttr<std::string>(size_t idx) const {
32 33 34 35 36 37 38 39
  const auto& it = desc_->attrs()->Get(idx);
  if (!it->s()) {
    return std::string();
  }
  return it->s()->str();
}

template <>
40
lite::VectorView<std::string, Flatbuffers>
41
OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const {
42 43
  const auto& it = desc_->attrs()->LookupByKey(name.c_str());
  CHECK(it) << "Attr " << name << "does not exist.";
44
  return VectorView<std::string>(it->strings());
45 46 47
}

template <>
48 49
VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(size_t idx) const {
50 51
  const auto& it = desc_->attrs()->Get(idx);
  CHECK(it) << "Attr " << idx << "does not exist.";
52
  return VectorView<std::string>(it->strings());
53 54
}

// GET_ATTR_IMPL(T, fb_f__) defines the two OpDescView::GetAttr<T>
// specializations (lookup by name and by index) for a scalar attribute
// type T. Each returns the attr's `fb_f__()` accessor result directly.
// Both fail a CHECK when the requested attribute entry is null.
#define GET_ATTR_IMPL(T, fb_f__)                                             \
  template <>                                                                \
  typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
      const std::string& name) const {                                       \
    const auto* it = desc_->attrs()->LookupByKey(name.c_str());              \
    CHECK(it) << "Attr " << name << " does not exist.";                      \
    return it->fb_f__();                                                     \
  }                                                                          \
  template <>                                                                \
  typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
      size_t idx) const {                                                    \
    const auto* it = desc_->attrs()->Get(idx);                               \
    CHECK(it) << "Attr " << idx << " does not exist.";                       \
    return it->fb_f__();                                                     \
  }

// GET_ATTRS_IMPL(T, fb_f__) defines the two OpDescView::GetAttr<T>
// specializations (lookup by name and by index) for a repeated attribute
// type T. Each wraps the attr's `fb_f__()` result in the trait's return type
// RT. Both fail a CHECK when the requested attribute entry is null.
#define GET_ATTRS_IMPL(T, fb_f__)                                            \
  template <>                                                                \
  typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
      const std::string& name) const {                                       \
    const auto* it = desc_->attrs()->LookupByKey(name.c_str());              \
    CHECK(it) << "Attr " << name << " does not exist.";                      \
    return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
  }                                                                          \
  template <>                                                                \
  typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
      size_t idx) const {                                                    \
    const auto* it = desc_->attrs()->Get(idx);                               \
    CHECK(it) << "Attr " << idx << " does not exist.";                       \
    return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
  }

// Instantiate the read-only OpDescView getters: scalar attrs map to their
// single-value accessor, repeated attrs to their vector accessor.
GET_ATTR_IMPL(int32_t, i);
GET_ATTR_IMPL(int16_t, block_idx);
GET_ATTR_IMPL(float, f);
GET_ATTR_IMPL(bool, b);
GET_ATTR_IMPL(int64_t, l);
GET_ATTRS_IMPL(std::vector<int>, ints);
GET_ATTRS_IMPL(std::vector<float>, floats);
GET_ATTRS_IMPL(std::vector<int64_t>, longs);
#undef GET_ATTR_IMPL
#undef GET_ATTRS_IMPL

#define ATTR_IMPL(T, fb_f__)                                  \
  template <>                                                 \
  T OpDesc::GetAttr<T>(const std::string& name) const {       \
    return (*GetKeyIterator(name, desc_->attrs))->fb_f__;     \
  }                                                           \
  template <>                                                 \
  void OpDesc::SetAttr(const std::string& name, const T& v) { \
    (*GetKeyIterator(name, desc_->attrs))->fb_f__ = v;        \
  }
ATTR_IMPL(int32_t, i);
ATTR_IMPL(int16_t, block_idx);
ATTR_IMPL(float, f);
ATTR_IMPL(bool, b);
ATTR_IMPL(int64_t, l);
ATTR_IMPL(std::vector<int>, ints);
ATTR_IMPL(std::vector<float>, floats);
ATTR_IMPL(std::vector<int64_t>, longs);
#undef GET_ATTRS_IMPL
112 113 114 115

}  // namespace fbs
}  // namespace lite
}  // namespace paddle