/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef INC_GRAPH_UTILS_ATTR_UTILS_H_ #define INC_GRAPH_UTILS_ATTR_UTILS_H_ #include #include #include #include "graph/detail/attributes_holder.h" #include "graph/ge_attr_value.h" #include "graph/types.h" namespace ge { class OpDesc; using OpDescPtr = std::shared_ptr; using ConstOpDescPtr = std::shared_ptr; class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { public: class ConstAttrHolderAdapter; class AttrHolderAdapter; // Set static bool HasAttr(ConstAttrHolderAdapter &&obj, const string &name); static bool SetInt(AttrHolderAdapter &&obj, const string &name, const int64_t &value); static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list &&value); static bool SetFloat(AttrHolderAdapter &&obj, const string &name, const float &value); static bool SetListFloat(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetBool(AttrHolderAdapter &&obj, const string &name, const bool &value); static bool SetListBool(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetStr(AttrHolderAdapter &&obj, const string &name, const string &value); static bool SetListStr(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetTensorDesc(AttrHolderAdapter &&obj, const string &name, const GeTensorDesc &value); static bool SetListTensorDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensorPtr &value); static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const ConstGeTensorPtr &value); static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensor &value); static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, std::initializer_list &&value); static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetGraph(AttrHolderAdapter &&obj, const string &name, const ComputeGraphPtr &value); static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value); static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NamedAttrs &value); static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); // Get static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int64_t &value); static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int32_t &value); static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, uint32_t &value); static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); static bool GetFloat(ConstAttrHolderAdapter &&obj, const string &name, float &value); static bool GetListFloat(ConstAttrHolderAdapter &&obj, const string &name, vector &value); static bool GetBool(ConstAttrHolderAdapter &&obj, const string &name, bool &value); static bool GetListBool(ConstAttrHolderAdapter &&obj, const string &name, vector &value); static bool GetStr(ConstAttrHolderAdapter &&obj, const string &name, string &value); static bool GetListStr(ConstAttrHolderAdapter &&obj, const string &name, vector &value); static bool GetTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, GeTensorDesc &value); static bool GetListTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, vector &value); static bool GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value); static bool MutableTensor(AttrHolderAdapter &&obj, const string &name, GeTensorPtr &value); static bool GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector &value); static bool MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector &value); static bool GetGraph(ConstAttrHolderAdapter &&obj, const string &name, ComputeGraphPtr &value); static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector &value); static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value); static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &value); static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NamedAttrs &value); static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, vector &value); static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector &value); // Value will be moved static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer); // Value will be moved static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector &listBuffer); static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &listBuffer); static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector> &value); static bool GetListListInt(ConstAttrHolderAdapter &&obj, const string &name, vector> &value); static bool SetListDataType(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool GetListDataType(ConstAttrHolderAdapter &&obj, const string &name, vector &value); static bool SetDataType(AttrHolderAdapter &&obj, const string &name, const ge::DataType &value); static bool GetDataType(ConstAttrHolderAdapter &&obj, const string &name, ge::DataType &value); static OpDescPtr CloneOpDesc(const ConstOpDescPtr &orgOpDesc); static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc); class AttrHolderAdapter { public: AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {} ~AttrHolderAdapter() {} template AttrHolderAdapter(const std::shared_ptr &obj) : obj_(obj.get()) {} AttrHolderAdapter(AttrHolder &obj) : obj_(&obj) {} operator bool() const { return obj_ != nullptr; } AttrHolder *operator->() { return obj_; } AttrHolder *get() { return obj_; } AttrHolder *obj_; }; class ConstAttrHolderAdapter { public: ConstAttrHolderAdapter(const AttrHolder *obj) : obj_(obj) {} ~ConstAttrHolderAdapter() {} template ConstAttrHolderAdapter(const std::shared_ptr obj) : obj_(obj.get()) {} ConstAttrHolderAdapter(const AttrHolder &obj) : obj_(&obj) {} operator bool() const { return obj_ != nullptr; } const AttrHolder *operator->() const { return obj_; } const AttrHolder *get() const { return obj_; } private: const AttrHolder *obj_; }; }; } // namespace ge #endif // INC_GRAPH_UTILS_ATTR_UTILS_H_