trt_kernels.cc 7.0 KB
Newer Older
W
Wilber 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/kernel/tensorrt/trt_kernels.h"
#include <string>
W
Wilber 已提交
17
#include <unordered_set>
W
Wilber 已提交
18 19 20 21 22 23 24
#include "NvInfer.h"
#include "NvInferRuntime.h"
#include "NvInferRuntimeCommon.h"
#include "glog/logging.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
W
Wilber 已提交
25
#include "mlir/IR/BuiltinAttributes.h"
W
Wilber 已提交
26 27 28
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
W
Wilber 已提交
29 30 31 32

#include "paddle/infrt/kernel/tensorrt/trt_helper.h"
#include "paddle/infrt/kernel/tensorrt/trt_layers.h"

W
Wilber 已提交
33 34 35 36
#include "paddle/infrt/backends/tensorrt/trt_engine.h"
#include "paddle/infrt/backends/tensorrt/trt_options.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "paddle/infrt/host_context/symbol_table.h"
W
Wilber 已提交
37
#include "paddle/phi/common/place.h"
W
Wilber 已提交
38 39 40 41 42 43 44
#include "paddle/phi/core/dense_tensor.h"

namespace infrt {
namespace kernel {
namespace tensorrt {

::infrt::backends::tensorrt::TrtEngine CreateTrtEngine(
W
Wilber 已提交
45
    MlirOperationWithInfrtSymbol create_engine_op) {
W
Wilber 已提交
46 47 48 49 50 51 52 53 54 55 56 57 58 59
  // TODO(wilber): The device_id needs to get from mlir.
  int device_id = 0;
  backends::tensorrt::TrtEngine engine(device_id);

  auto* builder = engine.GetTrtBuilder();
  // TODO(wilber): How to process weights?
  backends::tensorrt::TrtUniquePtr<nvinfer1::INetworkDefinition> network;
  // TODO(wilber): static_shape or dynamic_shape network? The code is just
  // static_shape test.
  network.reset(builder->createNetworkV2(0));

  // TODO(wilber): The build option shoule be fiiled from mlir info.
  backends::tensorrt::BuildOptions options;
  options.max_batch = 4;
W
Wilber 已提交
60
  options.workspace = 1024;
W
Wilber 已提交
61 62 63 64 65 66 67 68 69 70 71

  // Parse mlir Region which only has one block.
  mlir::Operation& operation = *create_engine_op.operation;
  auto* symbol_table = create_engine_op.symbol_table;
  CHECK_NOTNULL(symbol_table);

  unsigned int num_regions = operation.getNumRegions();
  CHECK_EQ(num_regions, 1U) << "only support one region case.";
  auto& region = operation.getRegion(0);
  auto& block = region.getBlocks().front();

W
Wilber 已提交
72
  std::unordered_map<std::string, ::phi::DenseTensor*> trt_bind_inputs;
W
Wilber 已提交
73 74
  ValueToITensorMap value_to_trt_tensor_map;
  ValueToTensorMap value_to_tensor_map;
W
Wilber 已提交
75 76 77 78 79 80 81 82

  for (auto index_operand : llvm::enumerate(operation.getOperands())) {
    mlir::Value operand = index_operand.value();
    size_t idx = index_operand.index();

    const std::string input_name = "input_" + std::to_string(idx);
    auto* v = symbol_table->GetValue(std::to_string(idx));
    CHECK_NOTNULL(v);
W
Wilber 已提交
83
    auto* t = &v->get<::phi::DenseTensor>();
W
Wilber 已提交
84 85
    value_to_tensor_map[operand] = t;

W
Wilber 已提交
86
    // TODO(wilber): get input info from mlir.
W
Wilber 已提交
87

W
Wilber 已提交
88
    // TODO(wilber): input dims, now only support static_shape, and just remove
W
Wilber 已提交
89 90 91
    // the first dimension. If the first dim is not -1, maybe we can pass the
    // origin dims.

W
Wilber 已提交
92 93
    // TODO(wilber): now only suppot float input.

W
Wilber 已提交
94 95 96
    if (operand.isa<mlir::BlockArgument>()) {
      // TODO(wilber): A trick: the weights are CPU tensor and inputs are GPU
      // tensor, so we treat all GPU tensors as inputs to trt.
W
Wilber 已提交
97
      if (t->place().GetType() == ::phi::AllocationType::GPU) {
W
Wilber 已提交
98 99 100 101 102 103 104 105 106
        trt_bind_inputs[input_name] = t;
        nvinfer1::Dims dims;
        dims.nbDims = t->dims().size() - 1;
        for (int i = 0; i < dims.nbDims; ++i) {
          dims.d[i] = t->dims()[i + 1];
        }
        auto* in = network->addInput(
            input_name.c_str(), nvinfer1::DataType::kFLOAT, dims);
        value_to_trt_tensor_map[operand] = in;
W
Wilber 已提交
107
      }
W
Wilber 已提交
108 109
    } else {
      // TODO(wilber): Replace with the op name that generates the weights.
W
Wilber 已提交
110
      std::unordered_set<std::string> weight_flags{
111 112 113 114
          "phi_dt.tensor_map_get_tensor",
          "phi_dt.create_dense_tensor.cpu",
          "phi_dt.create_inited_dense_tensor.cpu.f32",
          "phi_dt.create_host_inited_dense_tensor.f32"};
W
Wilber 已提交
115 116
      if (!weight_flags.count(
              operand.getDefiningOp()->getName().getStringRef().str())) {
W
Wilber 已提交
117 118 119 120 121 122 123 124 125
        trt_bind_inputs[input_name] = t;
        nvinfer1::Dims dims;
        dims.nbDims = t->dims().size() - 1;
        for (int i = 0; i < dims.nbDims; ++i) {
          dims.d[i] = t->dims()[i + 1];
        }
        auto* in = network->addInput(
            input_name.c_str(), nvinfer1::DataType::kFLOAT, dims);
        value_to_trt_tensor_map[operand] = in;
W
Wilber 已提交
126 127 128
      }
    }
  }
W
Wilber 已提交
129 130 131 132 133 134 135 136 137 138 139 140 141 142

  // TODO(wilber): Find a way to add layer.
  for (auto& operation : block.without_terminator()) {
    if (trt::ActivationOp op = llvm::dyn_cast<trt::ActivationOp>(operation)) {
      ActivationFunc(
          op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
    } else if (trt::FullyConnectedOp op =
                   llvm::dyn_cast<trt::FullyConnectedOp>(operation)) {
      FcFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
    } else if (trt::ConvolutionOp op =
                   llvm::dyn_cast<trt::ConvolutionOp>(operation)) {
      ConvFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
    } else {
      CHECK(false) << "not supported operation.";
W
Wilber 已提交
143 144
    }
  }
W
Wilber 已提交
145 146 147 148 149 150 151 152 153 154

  for (auto index_operand :
       llvm::enumerate(block.getTerminator()->getOperands())) {
    mlir::Value arg = index_operand.value();
    CHECK(value_to_trt_tensor_map.count(arg));
    // TODO(wilber): A trick that we name trt output tensor's name as output_0,
    // output_1, ...
    value_to_trt_tensor_map[arg]->setName(
        ("output_" + std::to_string(index_operand.index())).c_str());
    network->markOutput(*value_to_trt_tensor_map[arg]);
W
Wilber 已提交
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
  }
  for (int i = 0; i < network->getNbOutputs(); ++i) {
    engine.PrepareOutputHandle(network->getOutput(i)->getName());
  }

  VLOG(3) << "trt engine build start.";
  engine.Build(std::move(network), options);
  VLOG(3) << "trt engine build done.";

  // TODO(wilber): get inference options from mlir.
  backends::tensorrt::InferenceOptions inference_options;
  inference_options.batch = 1;
  // TODO(wilber): bind trt input/output tensors.
  engine.SetUpInference(inference_options, trt_bind_inputs);
  return engine;
}

// Forwards to TrtEngine::GetEngineInfo() on `engine`, which must be non-null.
// NOTE(review): presumably GetEngineInfo reports/logs the engine's layer
// information — confirm against trt_engine.h.
void PrintTrtLayer(backends::tensorrt::TrtEngine* engine) {
  auto& trt_engine = *engine;
  trt_engine.GetEngineInfo();
}

W
Wilber 已提交
176 177
// Runs `engine` with `context` and returns pointers to its output tensors.
//
// Outputs are fetched by the names "output_0", "output_1", ... that
// CreateTrtEngine assigns to the network outputs, so the returned vector is
// ordered by output index. The tensors are owned by the engine; the caller
// receives non-owning pointers.
std::vector<::phi::DenseTensor*> TrtEngineCompute(
    backends::tensorrt::TrtEngine* engine, const ::phi::GPUContext& context) {
  CHECK_NOTNULL(engine);
  engine->Run(context);

  // Query the count once; cast keeps the loop comparison unsigned regardless
  // of GetOutputNum()'s declared return type.
  const size_t num_outputs = static_cast<size_t>(engine->GetOutputNum());
  std::vector<::phi::DenseTensor*> res;
  res.reserve(num_outputs);
  for (size_t i = 0; i < num_outputs; ++i) {
    res.push_back(engine->GetOutput("output_" + std::to_string(i)));
  }
  return res;
}

}  // namespace tensorrt
}  // namespace kernel
}  // namespace infrt