// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/rknpu/subgraph_compute.h"
#include <sys/time.h>
#include <time.h>
#include <utility>
#include "lite/backends/rknpu/device.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/rknpu/bridges/graph.h"
#include "lite/kernels/rknpu/bridges/paddle_use_bridges.h"
#include "lite/kernels/rknpu/bridges/utility.h"
#include "rknpu/rknpu_pub.h"  // NOLINT

namespace paddle {
namespace lite {
namespace kernels {
namespace rknpu {

bool SubgraphEngine::BuildDeviceProgram() {
32 33 34 35 36
  LOG(INFO) << "[RKNPU]:BuildDeviceProgram";
  int status = 0;
  // Convert all of ops and their input vars and weights and added into the NPU
  // RKNPU IR graph
  subgraph::rknpu::Graph graph;
37
  const auto& bridges = subgraph::SubgraphBridgeRegistry::Instance();
38
  if (!origin_program_) {
39 40
    BuildOriginProgram();
  }
41 42
  const auto& insts = origin_program_->instructions(kRootBlockIdx);
  for (auto& inst : insts) {
43 44 45 46 47 48
    auto op = const_cast<OpLite*>(inst.op());
    CHECK(op);
    op->CheckShape();
    op->InferShape();
    std::string op_type = op->op_info()->Type();
    if (!bridges.Exists(op_type, TARGET(kRKNPU))) {
49
      return false;
50 51 52 53 54
    }
    auto kernel = inst.kernel();
    status |= bridges.Select(op_type, TARGET(kRKNPU))(
        reinterpret_cast<void*>(&graph), op, const_cast<KernelBase*>(kernel));
    if (subgraph::CHECK_FAILED(status)) {
55
      return false;
56 57 58 59
    }
  }
  // Collect the valid input and output nodes in the RKNPU IR graph and update
  // the input and output names
60 61
  device_itensors_.clear();
  device_otensors_.clear();
62
  for (size_t i = 0; i < input_names_.size(); i++) {
63 64 65
    CHECK(graph.Has(input_names_[i])) << "[RKNPU] Failed to find input node "
                                      << input_names_[i];
    auto node = graph.Get(input_names_[i]);
66 67
    auto precision = node->precision();
    auto layout = node->layout();
68
    LOG(INFO) << "[RKNPU] Inputs[" << i << "] name: " << input_names_[i]
69 70
              << " precision: " << PrecisionToStr(precision)
              << " layout: " << DataLayoutToStr(layout);
71
    device_itensors_.push_back(node->data());
72
  }
73 74 75 76
  for (size_t i = 0; i < output_names_.size(); i++) {
    CHECK(graph.Has(output_names_[i])) << "[RKNPU] Failed to find output node "
                                       << output_names_[i];
    auto node = graph.Get(output_names_[i]);
77 78
    auto precision = node->precision();
    auto layout = node->layout();
79
    LOG(INFO) << "[RKNPU] Outputs[" << i << "] name: " << output_names_[i]
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
              << " precision: " << PrecisionToStr(precision)
              << " layout: " << DataLayoutToStr(layout);
    // Prepare the device output tensors
    switch (precision) {
      case PRECISION(kFloat):
        origin_otensors_[i]->mutable_data<float>();
        break;
      case PRECISION(kInt8):
        origin_otensors_[i]->mutable_data<int8_t>();
        break;
      case PRECISION(kInt16):
        origin_otensors_[i]->mutable_data<int16_t>();
        break;
      case PRECISION(kInt32):
        origin_otensors_[i]->mutable_data<int32_t>();
        break;
      case PRECISION(kInt64):
        origin_otensors_[i]->mutable_data<int64_t>();
        break;
      default:
100
        LOG(FATAL) << "[RKNPU] " << output_names_[i]
101 102 103 104
                   << " can't mutable data with precision type "
                   << PrecisionToStr(precision);
        break;
    }
105 106 107 108 109 110 111 112
    device_otensors_.push_back(node->data());
  }
  // Create the RKNPU model and set the input and output nodes
  device_program_ = lite::rknpu::Device::Global().Build(
      model_name_, graph.GetHandle(), device_itensors_, device_otensors_);
  if (device_program_ == nullptr) {
    LOG(WARNING) << "[RKNPU] Build model failed!";
    return false;
113
  }
114
  return true;
115 116
}

bool SubgraphEngine::LaunchDeviceProgram() {
118 119 120 121
  LOG(INFO) << "[RKNPU]:LaunchDeviceProgram";
  std::vector<rk::nn::InputInfo> inputs;
  std::vector<rk::nn::OutputInfo> outputs;

122 123
  inputs.resize(origin_itensors_.size());
  for (size_t i = 0; i < origin_itensors_.size(); i++) {
124 125 126 127 128 129 130 131 132
    inputs[i].index = i;
    inputs[i].buf = const_cast<void*>(origin_itensors_[i]->raw_data());
    inputs[i].size = origin_itensors_[i]->memory_size();
    inputs[i].pass_through = false;
    inputs[i].type =
        subgraph::rknpu::ToRknpuPrecisionType(origin_itensors_[i]->precision());
    inputs[i].layout = rk::nn::DataLayoutType::NCHW;
  }

133 134
  outputs.resize(origin_otensors_.size());
  for (size_t i = 0; i < origin_otensors_.size(); i++) {
135 136 137 138 139 140 141 142 143
    outputs[i].index = i;
    outputs[i].buf = const_cast<void*>(origin_otensors_[i]->raw_data());
    outputs[i].size = origin_otensors_[i]->memory_size();
    outputs[i].want_float = false;
  }

  device_program_->SetInputs(inputs);
  device_program_->Run();
  device_program_->GetOutputs(outputs);
144
  return true;
145 146 147 148 149 150
}

void SubgraphCompute::PrepareForRun() {
  LOG(INFO) << "[RKNPU]:PrepareForRun";
  auto& param = this->Param<param_t>();
  engine_.reset(new SubgraphEngine(ctx_.get(),
151 152 153
                                   param.block_idx,
                                   param.program_desc,
                                   param.exec_scope,
154
                                   param.input_data_names,
155
                                   param.output_data_names));
156 157 158 159 160 161
  CHECK(engine_);
}

// Executes the subgraph by delegating to the engine built in PrepareForRun().
void SubgraphCompute::Run() {
  LOG(INFO) << "[RKNPU]:Run";
  CHECK(engine_);
  engine_->Run();
}

}  // namespace rknpu
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Register the subgraph kernel for the RKNPU target: int8 precision, NCHW
// layout, with host-side int8/NCHW tensors bound to "Inputs"/"Outputs".
REGISTER_LITE_KERNEL(subgraph,
                     kRKNPU,
                     kInt8,
                     kNCHW,
                     paddle::lite::kernels::rknpu::SubgraphCompute,
                     def)
    .BindInput("Inputs",
               {LiteType::GetTensorTy(TARGET(kHost),
                                      PRECISION(kInt8),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Outputs",
                {LiteType::GetTensorTy(TARGET(kHost),
                                       PRECISION(kInt8),
                                       DATALAYOUT(kNCHW))})
    .Finalize();