// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/npu/device.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/io.h"

namespace paddle {
namespace lite {
namespace npu {

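// Load an om model into a new HiAI model manager client. If *model_comp is
// false and the compatibility check fails, the buffer is rebuilt in place for
// the current HiAI DDK before loading. Returns nullptr on failure.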
std::shared_ptr<hiai::AiModelMngerClient> Device::Load(
    const std::string& model_name,
    std::vector<char>* model_buffer,
    bool* model_comp) {
  // Create a HiAI model manager client to load the HiAI om model
  auto model_client = std::make_shared<hiai::AiModelMngerClient>();
  if (model_client->Init(nullptr) != hiai::AI_SUCCESS) {
    LOG(WARNING) << "[NPU] Init hiai model client failed!";
    return nullptr;
  }
  // Check HiAI DDK version
  const char* ddk_version = model_client->GetVersion();
  if (ddk_version) {
    VLOG(3) << "[NPU] HiAI DDK version: " << ddk_version;
  } else {
    LOG(WARNING) << "[NPU] Unable to get HiAI DDK version!";
  }
  // Check model compatibility
  auto model_desc = std::make_shared<hiai::AiModelDescription>(
      model_name, freq_level(), framework_type(), model_type(), device_type());
  model_desc->SetModelBuffer(
      reinterpret_cast<const void*>(model_buffer->data()),
      model_buffer->size());
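  // Skip the compatibility check if the caller has already marked the model
  // as compatible; otherwise flag the model for rebuilding when the check
  // fails.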
  if (!*model_comp &&
      model_client->CheckModelCompatibility(*model_desc, *model_comp) !=
          hiai::AI_SUCCESS) {
    *model_comp = false;
    VLOG(3) << "[NPU] model is NOT compatiblitiable, setting model_comp to "
            << *model_comp;
  } else {
    *model_comp = true;
    VLOG(3) << "[NPU] model is compatiblitiable, setting model_comp to "
            << *model_comp;
  }
  // Rebuild and write the data of the compatible model to the model buffer
  if (!*model_comp) {
    std::shared_ptr<hiai::AiModelBuilder> model_builder =
        std::make_shared<hiai::AiModelBuilder>(model_client);
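    // Wrap the existing (incompatible) om model data in a HiAI input buffer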
    hiai::MemBuffer* org_model_buffer = model_builder->InputMemBufferCreate(
        reinterpret_cast<void*>(model_buffer->data()), model_buffer->size());
    if (org_model_buffer) {
      std::vector<hiai::MemBuffer*> org_model_buffers;
      org_model_buffers.push_back(org_model_buffer);
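      // Create an output buffer to hold the rebuilt, compatible model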
      hiai::MemBuffer* new_model_buffer = model_builder->OutputMemBufferCreate(
          framework_type(), org_model_buffers);
      // VLOG(3) << "[NPU] new model buffer memory size is "
      //         << new_model_buffer->GetMemBufferSize();
      if (new_model_buffer) {
        uint32_t new_model_size = 0;
        if (model_builder->BuildModel(org_model_buffers,
                                      new_model_buffer,
                                      new_model_size) == hiai::AI_SUCCESS) {
          // Use new_model_size here because GetMemBufferSize() does not
          // report the correct size of the rebuilt model.
          model_buffer->resize(new_model_size);
          memcpy(reinterpret_cast<void*>(model_buffer->data()),
                 new_model_buffer->GetMemBufferData(),
                 new_model_size);
          // Reset the model buffer
          model_desc->SetModelBuffer(
              reinterpret_cast<const void*>(model_buffer->data()),
              model_buffer->size());
          VLOG(3) << "[NPU] Rebuild the compatible model done.";
        } else {
          LOG(WARNING) << "[NPU] Rebuild the compatible model failed!";
        }
        model_builder->MemBufferDestroy(new_model_buffer);
      } else {
        LOG(WARNING) << "[NPU] OutputMemBufferCreate failed!";
      }
      model_builder->MemBufferDestroy(org_model_buffer);
    } else {
      LOG(WARNING) << "[NPU] InputMemBufferCreate failed!";
    }
  }
  // Load the compatible model
  std::vector<std::shared_ptr<hiai::AiModelDescription>> model_descs{
      model_desc};
  if (model_client->Load(model_descs) != hiai::AI_SUCCESS) {
    LOG(WARNING) << "[NPU] AiModelMngerClient load model failed!";
    return nullptr;
  }
  VLOG(3) << "[NPU] Load model done.";
  return model_client;
}

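// Convert a HiAI IR graph, given its input and output operators, into a
// serialized om model and write it into model_buffer. Returns false if
// creating or building the om model buffer fails.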
bool Device::Build(std::vector<ge::Operator>& input_nodes,   // NOLINT
                   std::vector<ge::Operator>& output_nodes,  // NOLINT
                   std::vector<char>* model_buffer) {
  // Convert the HiAI IR graph to the HiAI om model
  ge::Graph ir_graph("graph");
  ir_graph.SetInputs(input_nodes).SetOutputs(output_nodes);
  ge::Model om_model("model", "model");
  om_model.SetGraph(ir_graph);

  // Build the HiAI om model, serialize and output it to the om buffer
  domi::HiaiIrBuild ir_build;
  domi::ModelBufferData om_buffer;
  if (!ir_build.CreateModelBuff(om_model, om_buffer)) {
    LOG(WARNING) << "[NPU] CreateModelBuff failed!";
    return false;
  }
  if (!ir_build.BuildIRModel(om_model, om_buffer)) {
    LOG(WARNING) << "[NPU] BuildIRModel failed!";
    ir_build.ReleaseModelBuff(om_buffer);
    return false;
  }
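  // Copy the serialized om model out of the builder-owned buffer before
  // releasing it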
  model_buffer->resize(om_buffer.length);
  memcpy(reinterpret_cast<void*>(model_buffer->data()),
         reinterpret_cast<void*>(om_buffer.data),
         om_buffer.length);
  ir_build.ReleaseModelBuff(om_buffer);
  VLOG(3) << "[NPU] Build model done.";
  return true;
}
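
// Illustrative usage sketch (an assumption, not part of the original file):
// a caller would typically chain Build() and Load(), e.g.
//
//   paddle::lite::npu::Device device;  // hypothetical instance; the real
//                                      // accessor may differ
//   std::vector<char> model_buffer;
//   if (device.Build(input_nodes, output_nodes, &model_buffer)) {
//     bool model_comp = true;  // freshly built, so assumed DDK-compatible
//     auto client = device.Load("model", &model_buffer, &model_comp);
//     // client is nullptr if loading failed
//   }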

}  // namespace npu
}  // namespace lite
}  // namespace paddle