attr_utils.h 8.8 KB
Newer Older
L
lujiale 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef INC_GRAPH_UTILS_ATTR_UTILS_H_
#define INC_GRAPH_UTILS_ATTR_UTILS_H_

#include <memory>
#include <string>
#include <vector>
#include "graph/detail/attributes_holder.h"
#include "graph/ge_attr_value.h"
#include "graph/types.h"

namespace ge {
class OpDesc;
// Shared-ownership handles to operator descriptors; used by the OpDesc-list
// attribute accessors and the CloneOpDesc/CopyOpDesc helpers below.
using OpDescPtr = std::shared_ptr<OpDesc>;
using ConstOpDescPtr = std::shared_ptr<const OpDesc>;

/**
 * Static helpers for reading and writing typed attributes on an AttrHolder.
 *
 * The target holder is passed through one of the nested adapter types. They
 * convert implicitly from a raw pointer, a reference, or a std::shared_ptr, so
 * callers can pass whichever handle they already have.
 * - Setters and the Mutable* accessors take an AttrHolderAdapter (mutable holder).
 * - Getters and HasAttr take a ConstAttrHolderAdapter (read-only holder).
 *
 * Every function returns bool. NOTE(review): presumably `false` means the holder
 * is null, the attribute is absent, or its stored type differs from the one
 * requested. Confirm against the implementation before relying on it.
 */
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils {
 public:
  class ConstAttrHolderAdapter;
  class AttrHolderAdapter;

  // Query
  static bool HasAttr(ConstAttrHolderAdapter &&obj, const string &name);

  // Set
  static bool SetInt(AttrHolderAdapter &&obj, const string &name, const int64_t &value);
  // Integer lists may be supplied as int64_t, uint32_t or int32_t elements.
  static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<int64_t> &value);
  static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<uint32_t> &value);
  static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<int32_t> &value);
  static bool SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list<int64_t> &&value);

  static bool SetFloat(AttrHolderAdapter &&obj, const string &name, const float &value);
  static bool SetListFloat(AttrHolderAdapter &&obj, const string &name, const vector<float> &value);
  static bool SetBool(AttrHolderAdapter &&obj, const string &name, const bool &value);
  static bool SetListBool(AttrHolderAdapter &&obj, const string &name, const vector<bool> &value);
  static bool SetStr(AttrHolderAdapter &&obj, const string &name, const string &value);
  static bool SetListStr(AttrHolderAdapter &&obj, const string &name, const vector<string> &value);
  static bool SetTensorDesc(AttrHolderAdapter &&obj, const string &name, const GeTensorDesc &value);
  static bool SetListTensorDesc(AttrHolderAdapter &&obj, const string &name, const vector<GeTensorDesc> &value);
  // Tensors may be supplied by mutable pointer, const pointer, or value.
  static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensorPtr &value);
  static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const ConstGeTensorPtr &value);
  static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensor &value);
  static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<GeTensorPtr> &value);
  static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<ConstGeTensorPtr> &value);
  static bool SetListTensor(AttrHolderAdapter &&obj, const string &name,
                            std::initializer_list<ConstGeTensorPtr> &&value);
  static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<GeTensor> &value);
  static bool SetGraph(AttrHolderAdapter &&obj, const string &name, const ComputeGraphPtr &value);
  static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector<ComputeGraphPtr> &value);
  static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value);
  static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector<GeAttrValue::BYTES> &value);
  static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NamedAttrs &value);
  static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name,
                                const vector<GeAttrValue::NamedAttrs> &value);
  static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<ConstOpDescPtr> &value);
  static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<OpDescPtr> &value);

  // Get
  // NOTE(review): the int32_t/uint32_t overloads read into narrower types than the
  // int64_t setter stores; check how out-of-range values are handled.
  static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int64_t &value);
  static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int32_t &value);
  static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, uint32_t &value);
  static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<int64_t> &value);
  static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<int32_t> &value);
  static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<uint32_t> &value);
  static bool GetFloat(ConstAttrHolderAdapter &&obj, const string &name, float &value);
  static bool GetListFloat(ConstAttrHolderAdapter &&obj, const string &name, vector<float> &value);
  static bool GetBool(ConstAttrHolderAdapter &&obj, const string &name, bool &value);
  static bool GetListBool(ConstAttrHolderAdapter &&obj, const string &name, vector<bool> &value);
  static bool GetStr(ConstAttrHolderAdapter &&obj, const string &name, string &value);
  static bool GetListStr(ConstAttrHolderAdapter &&obj, const string &name, vector<string> &value);
  static bool GetTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, GeTensorDesc &value);
  static bool GetListTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<GeTensorDesc> &value);
  // GetTensor/GetListTensor yield read-only tensors from a const holder.
  // MutableTensor/MutableListTensor take a mutable holder and yield non-const
  // tensor pointers.
  static bool GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value);
  static bool MutableTensor(AttrHolderAdapter &&obj, const string &name, GeTensorPtr &value);
  static bool GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector<ConstGeTensorPtr> &value);
  static bool MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector<GeTensorPtr> &value);
  static bool GetGraph(ConstAttrHolderAdapter &&obj, const string &name, ComputeGraphPtr &value);
  static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector<ComputeGraphPtr> &value);
  static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value);
  static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<GeAttrValue::BYTES> &value);
  static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NamedAttrs &value);
  static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name,
                                vector<GeAttrValue::NamedAttrs> &value);
  static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<OpDescPtr> &value);

  // Zero-copy byte buffers.
  // Value will be moved: do not rely on the contents of `buffer` after the call.
  static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer);
  static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer);
  // Value will be moved: although `listBuffer` is taken by non-const lvalue
  // reference, its elements are moved from. Do not rely on their contents after
  // the call.
  static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name,
                                   vector<Buffer> &listBuffer);
  static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer);

  static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector<vector<int64_t>> &value);
  static bool GetListListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<vector<int64_t>> &value);

  static bool SetListDataType(AttrHolderAdapter &&obj, const string &name, const vector<ge::DataType> &value);
  static bool GetListDataType(ConstAttrHolderAdapter &&obj, const string &name, vector<ge::DataType> &value);

  static bool SetDataType(AttrHolderAdapter &&obj, const string &name, const ge::DataType &value);
  static bool GetDataType(ConstAttrHolderAdapter &&obj, const string &name, ge::DataType &value);

  // Both build a new OpDesc from `orgOpDesc`.
  // NOTE(review): the difference between Clone and Copy (which fields or
  // attributes each duplicates) is not visible here; document it from the
  // implementation.
  static OpDescPtr CloneOpDesc(const ConstOpDescPtr &orgOpDesc);

  static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc);

  /**
   * Non-owning handle to a mutable AttrHolder. It is the target parameter of the
   * setters and the Mutable* accessors.
   *
   * The constructors are intentionally implicit, so a raw pointer, a reference,
   * or any std::shared_ptr<T> whose T* converts to AttrHolder* can be passed
   * directly. Only the raw pointer is stored: the adapter does not extend the
   * lifetime of a shared_ptr's object.
   */
  class AttrHolderAdapter {
   public:
    AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {}
    template <class T>
    AttrHolderAdapter(const std::shared_ptr<T> &obj) : obj_(obj.get()) {}
    AttrHolderAdapter(AttrHolder &obj) : obj_(&obj) {}

    // True when the adapter refers to a holder, i.e. was not built from null.
    operator bool() const { return obj_ != nullptr; }
    // Shallow const, like a pointer: a const adapter still yields a mutable holder.
    AttrHolder *operator->() const { return obj_; }
    AttrHolder *get() const { return obj_; }

    // NOTE(review): public, unlike ConstAttrHolderAdapter::obj_. It is kept
    // public for source compatibility with existing users.
    AttrHolder *obj_;
  };

  /**
   * Non-owning handle to a read-only AttrHolder. It is the target parameter of
   * the getters and HasAttr.
   *
   * The constructors are intentionally implicit, so a raw pointer, a reference,
   * or any std::shared_ptr<T> whose T* converts to const AttrHolder* can be
   * passed directly. Only the raw pointer is stored.
   */
  class ConstAttrHolderAdapter {
   public:
    ConstAttrHolderAdapter(const AttrHolder *obj) : obj_(obj) {}
    // Taken by const reference rather than by value. This avoids an atomic
    // reference-count increment/decrement on every implicit conversion, and
    // matches AttrHolderAdapter.
    template <class T>
    ConstAttrHolderAdapter(const std::shared_ptr<T> &obj) : obj_(obj.get()) {}
    ConstAttrHolderAdapter(const AttrHolder &obj) : obj_(&obj) {}

    // True when the adapter refers to a holder, i.e. was not built from null.
    operator bool() const { return obj_ != nullptr; }
    const AttrHolder *operator->() const { return obj_; }
    const AttrHolder *get() const { return obj_; }

   private:
    const AttrHolder *obj_;
  };
};
}  // namespace ge
#endif  // INC_GRAPH_UTILS_ATTR_UTILS_H_