utility.cc 4.4 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "lite/kernels/npu/bridges/utility.h"
16
#include <utility>
Y
Yan Chunwei 已提交
17 18 19

namespace paddle {
namespace lite {
20
namespace subgraph {
Y
Yan Chunwei 已提交
21 22
namespace npu {

23 24 25 26 27 28 29 30 31 32 33 34 35
// Reports whether an op has a usable input bound to `argname`.
//
// Returns true only when all of the following hold:
//   1. `argname` is one of the op's declared input argument names,
//   2. that argument lists at least one variable name, and
//   3. the first listed variable is found in `scope`.
// Returns false in every other case.
bool HasInputArg(const OpInfo* op_info,
                 const Scope* scope,
                 const std::string& argname) {
  const auto& arg_names = op_info->input_argnames();
  const bool is_declared =
      std::find(arg_names.begin(), arg_names.end(), argname) != arg_names.end();
  if (!is_declared) {
    return false;
  }
  const auto& var_names = op_info->Input(argname);
  if (var_names.empty()) {
    return false;
  }
  // Only the first variable bound to the argument is looked up.
  return scope->FindVar(var_names.front()) != nullptr;
}

41
// Maps a Lite precision type to the corresponding HiAI `ge::DataType`.
//
// Supported mappings:
//   kFloat -> DT_FLOAT, kInt8 -> DT_INT8, kInt32 -> DT_INT32.
// Any other precision is reported through LOG(FATAL). DT_FLOAT is returned
// only if that log call does not terminate the process.
ge::DataType CvtPrecisionType(PrecisionType itype) {
  switch (itype) {
    case PRECISION(kFloat):
      return ge::DT_FLOAT;
    case PRECISION(kInt8):
      return ge::DT_INT8;
    case PRECISION(kInt32):
      return ge::DT_INT32;
    default:
      break;
  }
  LOG(FATAL) << "[NPU] Can not convert precision type("
             << PrecisionToStr(itype) << ") from Lite to NPU";
  return ge::DT_FLOAT;
}

61
// Maps a Lite data layout to the corresponding HiAI `ge::Format`.
//
// Only kNCHW (-> FORMAT_NCHW) is supported. Any other layout is reported
// through LOG(FATAL). FORMAT_NCHW is returned only if that log call does
// not terminate the process.
ge::Format CvtDataLayoutType(DataLayoutType itype) {
  if (itype == DATALAYOUT(kNCHW)) {
    return ge::FORMAT_NCHW;
  }
  // TODO(hong19860320) support more data layout type
  LOG(FATAL) << "[NPU] Can not convert data layout type("
             << DataLayoutToStr(itype) << ") from Lite to NPU";
  return ge::FORMAT_NCHW;
}

76
// Builds a HiAI `ge::Tensor` from a Lite tensor.
//
// in_tensor: source tensor; read as the element type selected by `in_ptype`.
// out_shape: shape given to the NPU tensor. When empty, the source tensor's
//            own dims are used. Its element count must equal the source's
//            (enforced by CHECK_EQ).
// in_ptype:  element precision; kFloat, kInt32 and kInt8 are handled, and any
//            other value is reported through LOG(FATAL).
// in_ltype:  data layout; converted via CvtDataLayoutType and then required
//            to be NCHW (enforced by CHECK_EQ).
// Returns a newly allocated tensor whose descriptor is built from
// (out_shape, layout, precision) and whose data is set from the source
// buffer via `SetData`.
// NOTE(review): whether SetData copies or aliases the buffer is up to the
// HiAI implementation; confirm before relying on in_tensor's lifetime.
ge::TensorPtr CvtTensor(const Tensor& in_tensor,
                        std::vector<int64_t> out_shape,
                        PrecisionType in_ptype,
                        DataLayoutType in_ltype) {
  const auto in_size = in_tensor.dims().production();
  if (out_shape.empty()) {
    out_shape = in_tensor.dims().Vectorize();
  }

  const uint8_t* in_data = nullptr;
  // Bytes per element; stays 0 only on the unsupported (fatal) path.
  size_t elem_bytes = 0;
  if (in_ptype == PRECISION(kFloat)) {
    in_data = reinterpret_cast<const uint8_t*>(in_tensor.data<float>());
    elem_bytes = sizeof(float);
  } else if (in_ptype == PRECISION(kInt32)) {
    in_data = reinterpret_cast<const uint8_t*>(in_tensor.data<int32_t>());
    elem_bytes = sizeof(int32_t);
  } else if (in_ptype == PRECISION(kInt8)) {
    in_data = reinterpret_cast<const uint8_t*>(in_tensor.data<int8_t>());
    elem_bytes = sizeof(int8_t);
  } else {
    LOG(FATAL) << "[NPU] Unknown precision type " << PrecisionToStr(in_ptype);
  }
  // Compute the byte count in size_t. Narrowing the int64_t * size_t product
  // into `int` would overflow for tensors of 2 GiB or more.
  const size_t in_bytes = static_cast<size_t>(in_size) * elem_bytes;

  ge::DataType out_ptype = CvtPrecisionType(in_ptype);
  ge::Format out_ltype = CvtDataLayoutType(in_ltype);

  ge::TensorDesc out_desc(ge::Shape(out_shape), out_ltype, out_ptype);
  CHECK_EQ(out_ltype, ge::FORMAT_NCHW);

  auto out_size = out_desc.GetShape().GetShapeSize();
  CHECK_EQ(out_size, in_size);

  ge::TensorPtr out_tensor = std::make_shared<ge::Tensor>();
  out_tensor->SetTensorDesc(out_desc);
  out_tensor->SetData(in_data, in_bytes);
  return out_tensor;
}

114 115
// Returns the integer activation-mode code used by the NPU bridges for the
// Paddle activation op type `act_type`.
// NOTE(review): presumably consumed as the `mode` of HiAI's activation op;
// confirm against the DDK. Codes 7 and above 10 are not produced here.
// Unsupported types are reported through LOG(FATAL). 1 is returned only if
// that log call does not terminate the process.
int CvtActMode(std::string act_type) {
  // Lookup table of (op type, mode code); "relu_clipped" and "relu6" share
  // mode 3.
  static const std::pair<const char*, int> kActModes[] = {
      {"sigmoid", 0},
      {"relu", 1},
      {"tanh", 2},
      {"relu_clipped", 3},
      {"relu6", 3},
      {"elu", 4},
      {"leaky_relu", 5},
      {"abs", 6},
      {"softsign", 8},
      {"softplus", 9},
      {"hard_sigmoid", 10},
  };
  for (const auto& entry : kActModes) {
    if (act_type == entry.first) {
      return entry.second;
    }
  }
  // TODO(hong19860320) support more activation mode
  LOG(FATAL) << "[NPU] Unsupported activation type " << act_type;
  return 1;
}

Y
Yan Chunwei 已提交
143
}  // namespace npu
144
}  // namespace subgraph
Y
Yan Chunwei 已提交
145 146
}  // namespace lite
}  // namespace paddle