// trt_helper.h (committed by Wilber)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <NvInfer.h>
#include <NvInferRuntime.h>
#include <NvInferRuntimeCommon.h>

#include "glog/logging.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"

namespace infrt {
namespace kernel {
namespace tensorrt {

/// Maps a phi tensor dtype to the TensorRT data type used for weights.
///
/// Supported mappings:
///   FLOAT32 -> nvinfer1::DataType::kFLOAT
///   INT32   -> nvinfer1::DataType::kINT32
///   FLOAT16 -> nvinfer1::DataType::kHALF
///
/// Any other dtype aborts the process via LOG(FATAL). An unsupported dtype is
/// a reachable input condition, so it must not be reported with
/// `llvm_unreachable`, which may be compiled into undefined behaviour in
/// release (NDEBUG) builds.
inline nvinfer1::DataType TensorTypeToWeightType(::phi::DataType tensor_type) {
  switch (tensor_type) {
    case ::phi::DataType::FLOAT32:
      return nvinfer1::DataType::kFLOAT;
    case ::phi::DataType::INT32:
      return nvinfer1::DataType::kINT32;
    case ::phi::DataType::FLOAT16:
      return nvinfer1::DataType::kHALF;
    default:
      break;
  }
  LOG(FATAL) << "TensorTypeToWeightType: unsupported phi::DataType (enum value "
             << static_cast<int>(tensor_type) << ")";
  // LOG(FATAL) aborts; this only tells the compiler that control never falls
  // off the end of a non-void function.
  llvm_unreachable("should not reach here");
}

/// Converts an MLIR array of integer attributes into a TensorRT `Dims`.
///
/// @param int_array_attr  Non-empty array whose every element is an
///                        `mlir::IntegerAttr`; its length becomes `nbDims`.
/// @return Dims with `nbDims == int_array_attr.size()` and `d[i]` set to the
///         i-th integer value. Unused trailing `d` slots are zero-initialized.
///
/// Aborts (glog CHECK) if the array is empty, has more than
/// `nvinfer1::Dims::MAX_DIMS` elements (writing past `d` would overflow the
/// fixed-size array), or contains a non-integer element.
inline nvinfer1::Dims ArrayAttrToNvDims(const mlir::ArrayAttr& int_array_attr) {
  const std::size_t rank = int_array_attr.size();
  CHECK(rank > 0) << "ArrayAttrToNvDims: attribute array must not be empty";
  CHECK(rank <= static_cast<std::size_t>(nvinfer1::Dims::MAX_DIMS))
      << "ArrayAttrToNvDims: rank " << rank
      << " exceeds nvinfer1::Dims::MAX_DIMS ("
      << static_cast<int>(nvinfer1::Dims::MAX_DIMS) << ")";

  nvinfer1::Dims dims{};  // value-init: no garbage in unused d[] slots
  dims.nbDims = static_cast<int>(rank);
  for (int i = 0; i < dims.nbDims; ++i) {
    // Validate every element, not just the first, before reading it as int.
    auto int_attr = int_array_attr[i].dyn_cast<mlir::IntegerAttr>();
    CHECK(int_attr) << "ArrayAttrToNvDims: element " << i
                    << " is not an integer attribute";
    dims.d[i] = int_attr.getInt();
  }
  return dims;
}

/// Wraps a DenseTensor's buffer as TensorRT weights without copying.
///
/// @param tensor  Source tensor; must not be null. Its dtype must be one that
///                TensorTypeToWeightType accepts.
/// @return Weights whose `values` points directly at `tensor->data()` and whose
///         `count` is `tensor->numel()`. The data is aliased, not copied, so
///         the tensor's storage must stay alive for as long as the returned
///         Weights are used (NOTE(review): presumably at least until the
///         TensorRT engine is built — confirm against callers).
inline nvinfer1::Weights TensorToWeights(::phi::DenseTensor* tensor) {
  CHECK(tensor != nullptr) << "TensorToWeights: tensor must not be null";
  nvinfer1::Weights weights{};
  weights.type = TensorTypeToWeightType(tensor->dtype());
  weights.count = tensor->numel();
  weights.values = tensor->data();
  return weights;
}

}  // namespace tensorrt
}  // namespace kernel
}  // namespace infrt