ge_local_ops_kernel_info.cc 7.5 KB
Newer Older
L
lujiale 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h"

#include <cinttypes>
#include <limits>
#include <memory>

#include "common/constant/constant.h"
#include "framework/common/debug/ge_log.h"
#include "common/ge_inner_error_codes.h"
#include "common/ge/ge_util.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "op/op_factory.h"
#include "proto/task.pb.h"

namespace {
// Op type string that identifies a Constant node.
constexpr const char *kConstantOpType = "Constant";
// Name of the attribute that carries a Constant op's tensor.
constexpr const char *kConstantOpAttrName = "value";
}  // namespace
namespace ge {
namespace ge_local {
using domi::TaskDef;
using std::map;
using std::string;
using std::vector;

// Fills op_info_map_ with one entry per op reported by OpFactory::Instance().GetAllOps().
//
// Every entry receives the same OpInfo: GE-local engine name and kernel lib name, zero compute
// cost, and the partial/async/atomic flags cleared. The `options` argument is not read.
//
// Returns SUCCESS unconditionally.
Status GeLocalOpsKernelInfoStore::Initialize(const map<string, string> &options) {
  GELOGI("GeLocalOpsKernelInfoStore init start.");

  // Single descriptor shared by all ops handled by this store.
  const OpInfo shared_op_info = {.engine = kGeLocalEngineName,
                                 .opKernelLib = kGeLocalOpKernelLibName,
                                 .computeCost = 0,
                                 .flagPartial = false,
                                 .flagAsync = false,
                                 .isAtomic = false};

  // Register each op known to the factory under the shared descriptor.
  const auto registered_ops = OpFactory::Instance().GetAllOps();
  for (const auto &op_type : registered_ops) {
    op_info_map_[op_type] = shared_op_info;
  }

  GELOGI("GeLocalOpsKernelInfoStore inited success. op num=%zu", op_info_map_.size());

  return SUCCESS;
}

// Drops every entry from op_info_map_.
//
// Returns SUCCESS unconditionally.
Status GeLocalOpsKernelInfoStore::Finalize() {
  op_info_map_.clear();
  return SUCCESS;
}

// Computes and records the memory size of every output tensor of `ge_node`.
//
// For each output description of the node's op desc:
//   - if a positive size is already recorded and the data type is not DT_STRING, the output is
//     left untouched;
//   - otherwise the size is obtained from CalcConstantStrMemSize (Constant ops with DT_STRING
//     data) or TensorUtils::CalcTensorMemSize (everything else), checked to be non-negative and
//     to fit in a uint32_t, written into the tensor desc, and pushed back to the op desc.
//
// Returns SUCCESS when every output was processed. Returns FAILED when the op desc is null, a
// size calculation fails, the computed size is negative or exceeds the uint32_t range, or
// updating the output desc fails.
Status GeLocalOpsKernelInfoStore::CalcOpRunningParam(Node &ge_node) {
  OpDescPtr op_desc = ge_node.GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(FAILED, "CalcOpRunningParam failed, as op desc is null");
    return FAILED;
  }
  const string node_name = ge_node.GetName();
  const string node_type = ge_node.GetType();
  const size_t output_size = op_desc->GetOutputsSize();
  GELOGD("Calc op[%s:%s] op running param, output size=%zu.", node_name.c_str(), node_type.c_str(), output_size);

  // The size is stored as uint32_t, so that type's maximum is the real upper bound.
  const int64_t max_mem_size = static_cast<int64_t>(std::numeric_limits<uint32_t>::max());

  for (size_t i = 0; i < output_size; ++i) {
    // Work on a copy; it is written back through UpdateOutputDesc once the size is set.
    GeTensorDesc output_tensor = op_desc->GetOutputDesc(static_cast<uint32_t>(i));
    const Format format = output_tensor.GetFormat();
    const DataType data_type = output_tensor.GetDataType();

    uint32_t mem_size = 0;
    graphStatus graph_status = TensorUtils::GetSize(output_tensor, mem_size);
    // If mem size has been set, no need reset.
    if ((graph_status == GRAPH_SUCCESS) && (mem_size > 0) && (data_type != DT_STRING)) {
      GELOGD("Op[%s:%s] out[%zu] mem size has been set, no need calc again, format=%s, data_type=%s, mem_size=%u.",
             node_name.c_str(), node_type.c_str(), i, TypeUtils::FormatToSerialString(format).c_str(),
             TypeUtils::DataTypeToSerialString(data_type).c_str(), mem_size);
      continue;
    }

    int64_t output_mem_size = 0;
    GeShape output_shape = output_tensor.GetShape();
    if ((node_type == kConstantOpType) && (data_type == DT_STRING)) {
      // String constants take their size from the attribute tensor's data, not from the shape.
      graph_status = CalcConstantStrMemSize(op_desc, output_mem_size);
    } else {
      graph_status = TensorUtils::CalcTensorMemSize(output_shape, format, data_type, output_mem_size);
    }

    if (graph_status != GRAPH_SUCCESS) {
      GELOGE(FAILED, "Calc op[%s:%s] out[%zu] mem size failed, format=%s, data_type=%s, error=%u.", node_name.c_str(),
             node_type.c_str(), i, TypeUtils::FormatToSerialString(format).c_str(),
             TypeUtils::DataTypeToSerialString(data_type).c_str(), graph_status);
      return FAILED;
    }

    // PRId64 keeps the int64_t format specifier correct on every data model (not only LP64).
    if (output_mem_size < 0) {
      GELOGE(FAILED,
             "Calc op[%s:%s] out[%zu] mem size is negative(not support),"
             " format=%s, data_type=%s, mem_size=%" PRId64 ".",
             node_name.c_str(), node_type.c_str(), i, TypeUtils::FormatToSerialString(format).c_str(),
             TypeUtils::DataTypeToSerialString(data_type).c_str(), output_mem_size);
      return FAILED;
    }
    GELOGI(
        "Calc op[%s:%s] out[%zu] mem size is %" PRId64 ","
        " format=%s, data_type=%s.",
        node_name.c_str(), node_type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(),
        TypeUtils::DataTypeToSerialString(data_type).c_str());

    if (output_mem_size > max_mem_size) {
      GELOGE(FAILED,
             "Calc op[%s:%s] out[%zu] mem size failed, as GE need data, "
             "type is uint32, but output_mem_size[%" PRId64 "] is overflow.",
             node_name.c_str(), node_type.c_str(), i, output_mem_size);
      return FAILED;
    }
    TensorUtils::SetSize(output_tensor, static_cast<uint32_t>(output_mem_size));

    graph_status = op_desc->UpdateOutputDesc(static_cast<uint32_t>(i), output_tensor);
    if (graph_status != GRAPH_SUCCESS) {
      GELOGE(FAILED, "Update op[%s:%s] out[%zu] desc failed, format=%s, data_type=%s, error=%u.", node_name.c_str(),
             node_type.c_str(), i, TypeUtils::FormatToSerialString(format).c_str(),
             TypeUtils::DataTypeToSerialString(data_type).c_str(), graph_status);
      return FAILED;
    }
  }
  GELOGD("Calc op[%s:%s] running param success.", node_name.c_str(), node_type.c_str());
  return SUCCESS;
}

// Obtains the memory size of a Constant op from the tensor stored in its "value" attribute.
//
// op_desc:  op desc to read the attribute from; a null pointer is rejected.
// mem_size: on success, set to the size reported by the attribute tensor's data
//           (value->GetData().size()); left unchanged on failure.
//
// Returns SUCCESS when the size was read. Returns FAILED when op_desc is null, the placeholder
// tensor cannot be allocated, the attribute cannot be fetched, or the fetched tensor is null.
Status GeLocalOpsKernelInfoStore::CalcConstantStrMemSize(const OpDescPtr &op_desc, int64_t &mem_size) {
  if (op_desc == nullptr) {
    GELOGE(FAILED, "CalcConstantStrMemSize failed, as op desc is null");
    return FAILED;
  }
  // NOTE(review): GetTensor presumably overwrites `value`; if so, this placeholder allocation
  // is unnecessary - confirm against AttrUtils before removing it.
  ConstGeTensorPtr value = MakeShared<const GeTensor>();
  if (value == nullptr) {
    GELOGE(FAILED, "make shared ConstGeTensor exception.");
    return FAILED;
  }
  // Constant op attr name is "value"
  if (!AttrUtils::GetTensor(op_desc, kConstantOpAttrName, value)) {
    GELOGE(FAILED, "Get Constant op attr value failed");
    return FAILED;
  }
  // `value` is rewritten through a reference above, so re-validate before dereferencing it.
  if (value == nullptr) {
    GELOGE(FAILED, "Get Constant op attr value failed, as tensor is null");
    return FAILED;
  }
  mem_size = static_cast<int64_t>(value->GetData().size());
  // The declared return type is Status, so report success in that domain (matching FAILED above).
  return SUCCESS;
}

// Copies the store's whole op_info_map_ into `infos`, replacing its previous contents.
void GeLocalOpsKernelInfoStore::GetAllOpsKernelInfo(map<string, OpInfo> &infos) const {
  infos = op_info_map_;
}

// Creates the GE-local op for `node` through OpFactory and runs it.
//
// node:    node to create the op for.
// context: run context handed to OpFactory::CreateOp.
// tasks:   only its size is read here (for logging); this function does not modify it.
//
// Returns FAILED when the factory yields no op, the status of Run() when it is not SUCCESS,
// and SUCCESS otherwise.
Status GeLocalOpsKernelInfoStore::GenerateTask(const Node &node, RunContext &context, vector<TaskDef> &tasks) {
  const string node_name = node.GetName();
  const string node_type = node.GetType();
  GELOGD("Ge local generate task for node:%s(%s) begin, tasks.size()=%zu.", node_name.c_str(), node_type.c_str(),
         tasks.size());

  auto local_op = OpFactory::Instance().CreateOp(node, context);
  if (local_op == nullptr) {
    GELOGE(FAILED, "CreateOp for node:%s(%s) failed.", node_name.c_str(), node_type.c_str());
    return FAILED;
  }

  const Status run_status = local_op->Run();
  if (run_status != SUCCESS) {
    GELOGE(run_status, "Node:%s(%s) op run failed.", node_name.c_str(), node_type.c_str());
    return run_status;
  }

  GELOGI("Ge local generate task for node:%s(%s) end, tasks.size()=%zu.", node_name.c_str(), node_type.c_str(),
         tasks.size());
  // run_status is known to equal SUCCESS at this point.
  return SUCCESS;
}

bool GeLocalOpsKernelInfoStore::CheckSupported(const OpDescPtr &op_desc, std::string &) const {
  if (op_desc == nullptr) {
    return false;
  }
  return op_info_map_.count(op_desc->GetType()) > 0;
}

// Session-creation hook; this store performs no work for it.
//
// session_options: not read.
// Returns SUCCESS unconditionally.
Status GeLocalOpsKernelInfoStore::CreateSession(const map<string, string> &session_options) {
  // Do nothing
  return SUCCESS;
}

// Session-destruction hook; this store performs no work for it.
//
// session_options: not read.
// Returns SUCCESS unconditionally.
Status GeLocalOpsKernelInfoStore::DestroySession(const map<string, string> &session_options) {
  // Do nothing
  return SUCCESS;
}
}  // namespace ge_local
}  // namespace ge