utils.h 3.5 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/core/op_lite.h"
#include "lite/core/target_wrapper.h"
#include "lite/core/tensor.h"

namespace paddle {
namespace lite {
Z
zhupengyang 已提交
28
namespace kernels {
Y
Yan Chunwei 已提交
29
namespace npu {
Z
zhupengyang 已提交
30
namespace bridges {
Y
Yan Chunwei 已提交
31

32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
// Per-thread container holding shared ownership of ge::Operator objects.
//
// Global() returns a thread_local instance, so every thread works with its
// own list. Operators added here stay alive until clear() is called or the
// owning thread exits.
// NOTE(review): presumably used to keep HiAI graph operators alive while a
// graph is being built — confirm at the call sites.
class OpList {
 public:
  // Returns the calling thread's OpList instance.
  static OpList& Global() {
    static thread_local OpList x;
    return x;
  }

  // Releases every operator reference held by this list.
  void clear() { lists_.clear(); }

  // Takes shared ownership of `p`. `p` is a by-value sink, so it is moved into
  // storage instead of copied (avoids an extra atomic refcount inc/dec).
  void add(std::shared_ptr<ge::Operator> p) { lists_.push_back(std::move(p)); }

 private:
  std::vector<std::shared_ptr<ge::Operator>> lists_;
};

// Builds a HiAI IR graph from the given input and output operators, turns it
// into an om model, and stores the om model data in the lite tensor
// `model_data`.
//
// @param inputs     Operators acting as the graph inputs. Passed by non-const
//                   reference; the NOLINT markers presumably silence the lint
//                   rule against mutable reference parameters.
// @param outputs    Operators acting as the graph outputs.
// @param model_data Lite tensor that receives the om model data.
// @return           Presumably true on success and false when the model could
//                   not be built — TODO confirm against the definition.
bool BuildModel(std::vector<ge::Operator>& inputs,   // NOLINT
                std::vector<ge::Operator>& outputs,  // NOLINT
                lite::Tensor* model_data);

Y
Yan Chunwei 已提交
50 51 52 53 54 55 56 57 58 59 60 61
// Returns a name derived from `prefix`. NOTE(review): presumably unique among
// the names produced by this function (e.g. for naming generated graph
// operators) — confirm against the definition.
std::string UniqueName(const std::string& prefix);

// Maps a lite PrecisionType to the corresponding HiAI ge::DataType.
ge::DataType PrecisionConverter(PrecisionType itype);

// Maps a lite DataLayoutType to the corresponding HiAI ge::Format.
ge::Format DataLayoutConverter(DataLayoutType itype);

// Converts a lite Tensor into a HiAI ge::Tensor.
//
// @param in_tensor Source lite tensor.
// @param out_shape Dims for the resulting tensor; presumably an empty vector
//                  (the default) means "use in_tensor's own dims" — confirm.
// @param in_ptype  Precision of in_tensor; defaults to PRECISION(kFloat).
// @param in_ltype  Layout of in_tensor; defaults to DATALAYOUT(kNCHW).
// @return          The converted ge tensor.
ge::TensorPtr CvtFromLiteTensor(Tensor* in_tensor,
                                std::vector<int64_t> out_shape = {},
                                PrecisionType in_ptype = PRECISION(kFloat),
                                DataLayoutType in_ltype = DATALAYOUT(kNCHW));

template <typename T>
Y
Yan Chunwei 已提交
62 63
ge::TensorPtr CreateTensorAndFillData(std::vector<T> data,
                                      std::vector<int64_t> shape = {},
Y
Yan Chunwei 已提交
64 65 66 67 68 69 70 71 72 73 74 75
                                      ge::Format format = ge::FORMAT_NCHW) {
  const std::type_info& info = typeid(T);
  ge::DataType type = ge::DT_FLOAT;
  if (info == typeid(float)) {
    type = ge::DT_FLOAT;
  } else if (info == typeid(int8_t)) {
    type = ge::DT_INT8;
  } else if (info == typeid(int32_t)) {
    type = ge::DT_INT32;
  } else {
    LOG(FATAL) << "Unknow value type " << info.name();
  }
Y
Yan Chunwei 已提交
76 77 78 79 80 81 82 83 84
  if (shape.empty()) {
    shape = {static_cast<int64_t>(data.size())};
  } else {
    int size = 1;
    for (auto i : shape) {
      size *= i;
    }
    CHECK_EQ(data.size(), size);
  }
Y
Yan Chunwei 已提交
85 86 87
  ge::TensorDesc desc(ge::Shape(shape), format, type);
  ge::TensorPtr tensor = std::make_shared<ge::Tensor>();
  tensor->SetTensorDesc(desc);
Y
Yan Chunwei 已提交
88 89 90 91 92 93 94 95 96 97
  tensor->SetData(reinterpret_cast<uint8_t*>(data.data()),
                  data.size() * sizeof(T));
  return tensor;
}

// Convenience overload: builds a tensor of the given `shape` in which every
// element equals `value`. It expands `value` into a std::vector<T> holding
// one copy per element and delegates to the std::vector<T> overload.
//
// @param value  Fill value used for every element.
// @param shape  Tensor dims; defaults to {1}, i.e. a single element. The
//               element count is the product of these dims.
// @param format Layout recorded in the tensor descriptor (default NCHW).
// @return       The tensor produced by the std::vector<T> overload.
template <typename T>
ge::TensorPtr CreateTensorAndFillData(T value,
                                      std::vector<int64_t> shape = {1},
                                      ge::Format format = ge::FORMAT_NCHW) {
  int64_t num_elements = 1;
  for (size_t axis = 0; axis < shape.size(); ++axis) {
    num_elements *= shape[axis];
  }
  return CreateTensorAndFillData(
      std::vector<T>(num_elements, value), shape, format);
}

// Returns whether `op_info` has an input argument named `argname`.
// NOTE(review): `scope` is presumably consulted to check that the argument's
// variable(s) actually exist — confirm against the definition.
bool HasInputArg(const OpInfo* op_info,
                 const Scope* scope,
                 const std::string& argname);

Z
zhupengyang 已提交
109
}  // namespace bridges
Y
Yan Chunwei 已提交
110
}  // namespace npu
Z
zhupengyang 已提交
111
}  // namespace kernels
Y
Yan Chunwei 已提交
112 113
}  // namespace lite
}  // namespace paddle