op_desc.cc 5.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/model_parser/flatbuffers/op_desc.h"

namespace paddle {
namespace lite {
namespace fbs {

template <>
22 23
std::string OpDescView::GetAttr<std::string>(const char* name) const {
  const auto& it = desc_->attrs()->LookupByKey(name);
24 25 26 27 28 29 30
  if (!it->s()) {
    return std::string();
  }
  return it->s()->str();
}

template <>
31 32
std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
  return GetAttr<std::string>(name.c_str());
33 34 35
}

template <>
36
lite::VectorView<std::string, Flatbuffers>
37 38
OpDescView::GetAttr<std::vector<std::string>>(const char* name) const {
  const auto& it = desc_->attrs()->LookupByKey(name);
39
  CHECK(it) << "Attr " << name << "does not exist.";
40
  return VectorView<std::string>(it->strings());
41 42 43
}

template <>
44 45 46
lite::VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const {
  return GetAttr<std::vector<std::string>>(name.c_str());
47 48
}

49 50 51
// Defines both GetAttr<T> overloads on OpDescView for a scalar attribute type
// `T`, whose value is read through the flatbuffers Attr accessor `fb_f__()`.
// The lookup result is CHECKed so a missing attribute reports its name
// instead of dereferencing a null pointer.
#define GET_ATTR_IMPL(T, fb_f__)                                             \
  template <>                                                                \
  typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
      const char* name) const {                                              \
    const auto* it = desc_->attrs()->LookupByKey(name);                      \
    CHECK(it) << "Attr " << name << " does not exist.";                      \
    return it->fb_f__();                                                     \
  }                                                                          \
  template <>                                                                \
  typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
      const std::string& name) const {                                       \
    return GetAttr<T>(name.c_str());                                         \
  }

62 63
// Defines both GetAttr<T> overloads on OpDescView for a list attribute type
// `T`. The result type RT is constructed from the flatbuffers vector returned
// by the Attr accessor `fb_f__()`. The lookup result is CHECKed so a missing
// attribute reports its name instead of dereferencing a null pointer.
#define GET_ATTRS_IMPL(T, fb_f__)                                            \
  template <>                                                                \
  typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
      const char* name) const {                                              \
    const auto* it = desc_->attrs()->LookupByKey(name);                      \
    CHECK(it) << "Attr " << name << " does not exist.";                      \
    return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
  }                                                                          \
  template <>                                                                \
  typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
      const std::string& name) const {                                       \
    return GetAttr<T>(name.c_str());                                         \
  }

// Scalar attribute accessors: C++ type -> flatbuffers Attr field.
GET_ATTR_IMPL(int16_t, block_idx);
GET_ATTR_IMPL(int32_t, i);
GET_ATTR_IMPL(int64_t, l);
GET_ATTR_IMPL(float, f);
GET_ATTR_IMPL(bool, b);

// List attribute accessors: C++ vector type -> flatbuffers Attr field.
GET_ATTRS_IMPL(std::vector<int>, ints);
GET_ATTRS_IMPL(std::vector<int64_t>, longs);
GET_ATTRS_IMPL(std::vector<float>, floats);

// The generator macros are local to this file section.
#undef GET_ATTRS_IMPL
#undef GET_ATTR_IMPL

86 87 88 89 90 91 92 93 94 95 96 97
// Defines the OpDesc (object-API) accessors for attribute type `T`, stored in
// the AttrT member `fb_f__`.
//   GetAttr<T>(name)    reads `fb_f__` from the attr found by GetKeyIterator.
//   SetAttr<T>(name, v) inserts a fresh AttrT via InsertPair, stores `v` in
//                       `fb_f__`, records the attr type, and sets its key.
// The unique_ptr argument is already a prvalue, so no std::move is needed.
#define ATTR_IMPL(T, fb_f__)                                                   \
  template <>                                                                  \
  T OpDesc::GetAttr<T>(const std::string& name) const {                        \
    return (*GetKeyIterator(name, desc_->attrs))->fb_f__;                      \
  }                                                                            \
  template <>                                                                  \
  void OpDesc::SetAttr<T>(const std::string& name, const T& v) {               \
    auto& p = *InsertPair(                                                     \
        name,                                                                  \
        std::unique_ptr<proto::OpDesc_::AttrT>(new proto::OpDesc_::AttrT()),   \
        &(desc_->attrs));                                                      \
    p->fb_f__ = v;                                                             \
    p->type = ConvertAttrType(OpDataTypeTrait<T>::AT);                         \
    SetKey(name, &p);                                                          \
  }
ATTR_IMPL(int32_t, i);
ATTR_IMPL(int16_t, block_idx);
ATTR_IMPL(float, f);
ATTR_IMPL(bool, b);
ATTR_IMPL(int64_t, l);
106
ATTR_IMPL(std::string, s);
107 108 109
ATTR_IMPL(std::vector<int>, ints);
ATTR_IMPL(std::vector<float>, floats);
ATTR_IMPL(std::vector<int64_t>, longs);
110
ATTR_IMPL(std::vector<std::string>, strings);
111
#undef GET_ATTRS_IMPL
112 113 114 115

}  // namespace fbs
}  // namespace lite
}  // namespace paddle