// ge_model.cc -- last committed by lujiale.
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "model/ge_model.h"

#include <utility>

#include "common/debug/log.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/attr_utils.h"

namespace ge {
/// Resets the model to its default state.
///
/// Writes five default attributes on this object (memory size, stream count,
/// event count and weight size as 0, plus the target type as TARGET_TYPE_MINI)
/// and clears version_. The boolean results of the AttrUtils setters are
/// intentionally discarded.
void GeModel::Init() {
  // Numeric attributes all start at zero.
  static_cast<void>(AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0));
  static_cast<void>(AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0));
  static_cast<void>(AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0));
  static_cast<void>(AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0));

  // The target type defaults to TARGET_TYPE_MINI.
  static_cast<void>(AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI));

  // At this point the default attribute count is 5.
  version_ = 0;
}

/// Default constructor: prepares the attribute map via InitDefault() and then
/// applies the default attributes and version through Init().
GeModel::GeModel() {
  // The attribute storage must be initialized before Init() writes to it.
  this->attrs_.InitDefault();
  this->Init();
}

/// Returns a const reference to the graph held by this model (graph_).
const Graph &GeModel::GetGraph() const {
  return graph_;
}

std::shared_ptr<domi::ModelTaskDef> GeModel::GetModelTaskDefPtr() const { return this->task_; }

/// Returns a const reference to the TBE kernel store held by this model.
const TBEKernelStore &GeModel::GetTBEKernelStore() const {
  return tbe_kernal_store_;
}

/// Returns the weights buffer (weights_buffer_) by value.
Buffer GeModel::GetWeight() const {
  return weights_buffer_;
}

std::string GeModel::GetName() const { return this->name_; }

uint32_t GeModel::GetVersion() const { return this->version_; }

std::string GeModel::GetPlatformVersion() const { return this->platform_version_; }

uint8_t GeModel::GetPlatformType() const { return this->platform_type_; }

void GeModel::SetGraph(const Graph &graph) { this->graph_ = graph; }

/// Stores `task` as the model task definition; the shared pointer is copied,
/// so this model becomes an additional owner of the pointee.
void GeModel::SetModelTaskDef(const std::shared_ptr<domi::ModelTaskDef> &task) {
  task_ = task;
}

void GeModel::SetTBEKernelStore(const TBEKernelStore &tbe_kernal_store) {
  this->tbe_kernal_store_ = tbe_kernal_store;
}

void GeModel::SetWeight(const Buffer &weights_buffer) { this->weights_buffer_ = weights_buffer; }

void GeModel::SetName(const std::string &name) { this->name_ = name; }

void GeModel::SetVersion(uint32_t version) { this->version_ = version; }

void GeModel::SetPlatformVersion(const std::string &platform_version) { this->platform_version_ = platform_version; }

void GeModel::SetPlatformType(uint8_t platform_type) { this->platform_type_ = platform_type; }

void GeModel::SetAttr(const ProtoAttrMapHelper &attrs) { attrs_ = attrs; }

/// Returns the attribute map helper (attrs_) by value.
// NOTE(review): presumably the returned helper refers to the same underlying
// storage as attrs_, so writes through it reach this model -- confirm against
// the ProtoAttrMapHelper definition.
ProtoAttrMapHelper GeModel::MutableAttrMap() {
  return this->attrs_;
}

/// Builds a ConstProtoAttrMapHelper from the proto owner and proto message
/// exposed by attrs_.
ConstProtoAttrMapHelper GeModel::GetAttrMap() const {
  const auto &proto_owner = attrs_.GetProtoOwner();
  const auto &proto_msg = attrs_.GetProtoMsg();
  return ConstProtoAttrMapHelper(proto_owner, proto_msg);
}
}  // namespace ge