// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/model_parser/flatbuffers/op_desc.h"

namespace paddle {
namespace lite {
namespace fbs {

// Returns the string value of the attribute named `name`.
// CHECK-fails if no attribute with that name is found; returns an empty
// string if the attribute exists but its `s` accessor yields null.
template <>
std::string OpDesc::GetAttr<std::string>(const std::string& name) const {
  const auto* attr = desc_->attrs()->LookupByKey(name.c_str());
  // LookupByKey returns null for an unknown key; fail loudly instead of
  // dereferencing it below.
  CHECK(attr) << "Attr " << name << " does not exist.";
  if (!attr->s()) {
    return std::string();
  }
  return attr->s()->str();
}

// Returns the string value of the attribute stored at position `idx` of the
// op's attrs vector. Returns an empty string if that attribute's `s`
// accessor yields null.
template <>
std::string OpDesc::GetAttr<std::string>(size_t idx) const {
  const auto* attr = desc_->attrs()->Get(idx);
  // Same guard as the other index-based accessors in this file.
  CHECK(attr) << "Attr " << idx << " does not exist.";
  if (!attr->s()) {
    return std::string();
  }
  return attr->s()->str();
}

template <>
40 41
lite::VectorView<std::string, Flatbuffers>
OpDesc::GetAttr<std::vector<std::string>>(const std::string& name) const {
42 43
  const auto& it = desc_->attrs()->LookupByKey(name.c_str());
  CHECK(it) << "Attr " << name << "does not exist.";
44
  return VectorView<std::string>(it->strings());
45 46 47
}

template <>
48
VectorView<std::string, Flatbuffers> OpDesc::GetAttr<std::vector<std::string>>(
49 50 51
    size_t idx) const {
  const auto& it = desc_->attrs()->Get(idx);
  CHECK(it) << "Attr " << idx << "does not exist.";
52
  return VectorView<std::string>(it->strings());
53 54
}

// Defines both explicit specializations of OpDesc::GetAttr<T> (lookup by
// attribute name and by position in the attrs vector) for a scalar type T.
// The value is returned directly from the Attr accessor `fb_f__`
// (e.g. `i`, `f`, `b`). Name lookups CHECK-fail if the attribute is missing.
// No comments inside the body: a `//` comment on a continued line would
// swallow the following line after backslash splicing.
#define GET_ATTR_IMPL(T, fb_f__)                                         \
  template <>                                                            \
  typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \
      const std::string& name) const {                                   \
    const auto* attr = desc_->attrs()->LookupByKey(name.c_str());        \
    CHECK(attr) << "Attr " << name << " does not exist.";                \
    return attr->fb_f__();                                               \
  }                                                                      \
  template <>                                                            \
  typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>( \
      size_t idx) const {                                                \
    const auto* attr = desc_->attrs()->Get(idx);                         \
    CHECK(attr) << "Attr " << idx << " does not exist.";                 \
    return attr->fb_f__();                                               \
  }

// Defines both explicit specializations of OpDesc::GetAttr<T> (lookup by
// attribute name and by position in the attrs vector) for a list type T.
// The result of the Attr accessor `fb_f__` (e.g. `ints`, `floats`) is
// passed to the constructor of OpDataTypeTrait<T, Flatbuffers>::RT.
// Name lookups CHECK-fail if the attribute is missing.
#define GET_ATTRS_IMPL(T, fb_f__)                                              \
  template <>                                                                  \
  typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>(       \
      const std::string& name) const {                                         \
    const auto* attr = desc_->attrs()->LookupByKey(name.c_str());              \
    CHECK(attr) << "Attr " << name << " does not exist.";                      \
    return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(attr->fb_f__()); \
  }                                                                            \
  template <>                                                                  \
  typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDesc::GetAttr<T>(       \
      size_t idx) const {                                                      \
    const auto* attr = desc_->attrs()->Get(idx);                               \
    CHECK(attr) << "Attr " << idx << " does not exist.";                       \
    return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(attr->fb_f__()); \
  }

// Scalar attribute specializations: each T is read from the named Attr field
// (note int16_t maps to the `block_idx` field).
GET_ATTR_IMPL(int32_t, i);
GET_ATTR_IMPL(int16_t, block_idx);
GET_ATTR_IMPL(float, f);
GET_ATTR_IMPL(bool, b);
GET_ATTR_IMPL(int64_t, l);
// Numeric list attribute specializations.
GET_ATTRS_IMPL(std::vector<int>, ints);
GET_ATTRS_IMPL(std::vector<float>, floats);
GET_ATTRS_IMPL(std::vector<int64_t>, longs);

// The helper macros are only meant for this file; undefine them so they do
// not leak past their last use.
#undef GET_ATTRS_IMPL
#undef GET_ATTR_IMPL

}  // namespace fbs
}  // namespace lite
}  // namespace paddle