Unverified commit 86919910, authored by Wilber, committed by GitHub

Trt engine (#40649)

Parent e3b2a035
@@ -82,9 +82,176 @@ TrtUniquePtr<nvinfer1::INetworkDefinition> ConstructNetwork(
  return network;
}
TrtUniquePtr<nvinfer1::INetworkDefinition> ConstructFCNetwork(
nvinfer1::IBuilder* builder, nvinfer1::Dims dims, bool is_static_shape) {
TrtUniquePtr<nvinfer1::INetworkDefinition> network;
if (is_static_shape) {
network.reset(builder->createNetworkV2(0U));
} else {
auto networkFlags =
1U << static_cast<uint32_t>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
network.reset(builder->createNetworkV2(networkFlags));
}
ITensor* data =
network->addInput(model_input, nvinfer1::DataType::kFLOAT, dims);
CHECK_NOTNULL(data);
nvinfer1::Weights kernel_weights;
kernel_weights.type = nvinfer1::DataType::kFLOAT;
kernel_weights.count = 7840;
std::vector<float> weight_data(kernel_weights.count);
for (size_t i = 0; i < weight_data.size(); ++i) {
weight_data[i] = i % 255 * 0.02f;
}
kernel_weights.values = weight_data.data();
auto* layer = network->addFullyConnected(
*data, 10, kernel_weights, nvinfer1::Weights{});
CHECK_NOTNULL(layer);
auto* out = layer->getOutput(0);
out->setName(model_output);
network->markOutput(*out);
return network;
}
TrtUniquePtr<nvinfer1::INetworkDefinition> ConstructConvNetwork(
nvinfer1::IBuilder* builder, nvinfer1::Dims dims, bool is_static_shape) {
TrtUniquePtr<nvinfer1::INetworkDefinition> network;
if (is_static_shape) {
network.reset(builder->createNetworkV2(0U));
} else {
auto networkFlags =
1U << static_cast<uint32_t>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
network.reset(builder->createNetworkV2(networkFlags));
}
ITensor* data =
network->addInput(model_input, nvinfer1::DataType::kFLOAT, dims);
CHECK_NOTNULL(data);
nvinfer1::Weights kernel_weights, bias_weights;
kernel_weights.type = nvinfer1::DataType::kFLOAT;
bias_weights.type = nvinfer1::DataType::kFLOAT;
kernel_weights.count = 81;
bias_weights.count = 3;
std::vector<float> weight_data(kernel_weights.count);
for (size_t i = 0; i < weight_data.size(); ++i) {
weight_data[i] = i * 0.02f;
}
std::vector<float> bias_data(bias_weights.count);
for (size_t i = 0; i < bias_data.size(); ++i) {
bias_data[i] = i * 0.5f;
}
kernel_weights.values = weight_data.data();
bias_weights.values = bias_data.data();
nvinfer1::Dims ksize;
ksize.nbDims = 2;
ksize.d[0] = 3;
ksize.d[1] = 3;
auto* layer =
network->addConvolutionNd(*data, 3, ksize, kernel_weights, bias_weights);
CHECK_NOTNULL(layer);
auto* out = layer->getOutput(0);
out->setName(model_output);
network->markOutput(*out);
return network;
}
// sigmoid(x) = 1 / (1 + exp(-x))
inline float sigmoid(float x) { return 1.f / (1.f + exp(-1 * x)); }
TEST(trt, run_fc_static) {
TrtEngine engine(0);
auto net = ConstructFCNetwork(
engine.GetTrtBuilder(), nvinfer1::Dims3{1, 28, 28}, true);
BuildOptions build_options;
build_options.max_batch = 4;
build_options.workspace = 1024;
engine.Build(std::move(net), build_options);
InferenceOptions inference_options;
inference_options.batch = 1;
phi::GPUPlace place;
phi::GPUContext context;
context.PartialInitWithoutAllocator();
context.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(place, context.stream())
.get());
context.PartialInitWithAllocator();
phi::DenseTensorMeta meta(
phi::DataType::FLOAT32,
phi::make_ddim({inference_options.batch, 1, 28, 28}));
phi::DenseTensor input;
input.set_meta(meta);
context.Alloc<float>(&input, input.numel() * sizeof(float));
std::vector<float> host_data(inference_options.batch * 1 * 28 * 28, 0);
for (size_t i = 0; i < host_data.size(); ++i) {
host_data[i] = i % 100 * 0.016f;
}
paddle::memory::Copy(place,
input.data<float>(),
phi::CPUPlace(),
host_data.data(),
sizeof(float) * host_data.size(),
context.stream());
std::unordered_map<std::string, phi::DenseTensor*> inputs;
inputs.emplace(std::make_pair(model_input, &input));
engine.PrepareOutputHandle("output_0");
engine.SetUpInference(inference_options, inputs);
engine.GetEngineInfo();
engine.Run(context);
cudaStreamSynchronize(context.stream());
}
TEST(trt, run_conv_static) {
TrtEngine engine(0);
auto net = ConstructConvNetwork(
engine.GetTrtBuilder(), nvinfer1::Dims3{3, 28, 28}, true);
BuildOptions build_options;
build_options.max_batch = 4;
build_options.workspace = 1024;
engine.Build(std::move(net), build_options);
InferenceOptions inference_options;
inference_options.batch = 1;
phi::GPUPlace place;
phi::GPUContext context;
context.PartialInitWithoutAllocator();
context.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(place, context.stream())
.get());
context.PartialInitWithAllocator();
phi::DenseTensorMeta meta(
phi::DataType::FLOAT32,
phi::make_ddim({inference_options.batch, 3, 28, 28}));
phi::DenseTensor input;
input.set_meta(meta);
context.Alloc<float>(&input, input.numel() * sizeof(float));
std::vector<float> host_data(inference_options.batch * 3 * 28 * 28, 0);
for (size_t i = 0; i < host_data.size(); ++i) {
host_data[i] = i % 100 * 0.016f;
}
paddle::memory::Copy(place,
input.data<float>(),
phi::CPUPlace(),
host_data.data(),
sizeof(float) * host_data.size(),
context.stream());
std::unordered_map<std::string, phi::DenseTensor*> inputs;
inputs.emplace(std::make_pair(model_input, &input));
engine.PrepareOutputHandle("output_0");
engine.SetUpInference(inference_options, inputs);
engine.GetEngineInfo();
engine.Run(context);
cudaStreamSynchronize(context.stream());
}
TEST(trt, run_static) {
  TrtEngine static_trt_engine(0);
  auto net = ConstructNetwork(
......
@@ -142,9 +142,6 @@ mlir::Type InfrtDialect::parseType(::mlir::DialectAsmParser &parser) const {
     return infrt::DenseTensorListType::get(parser.getContext());
   }
-  if (keyword == "dense_tensor_map") {
-    return DenseTensorMapType::get(parser.getContext());
-  }
   // Todo: parse other type
   return mlir::Type();
 }
@@ -181,6 +178,7 @@ void InfrtDialect::printType(::mlir::Type type,
   if (type.isa<infrt::DenseTensorListType>()) {
     os << "tensor_list";
+    return;
   }
   // print DenseTensorType, for example: !infrt.dense_tensor<CPU, FP32, NCHW>
   if (type.isa<DenseTensorMapType>()) {
......
@@ -60,6 +60,39 @@ def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> {
  let results = (outs DenseTensor:$output);
}
def TRT_FullyConnectedOp : TRT_Op<"FullyConnected", [NoSideEffect]> {
let summary = "TensorRT IFullyConnectedLayer";
let description = [{
TensorRT IFullyConnectedLayer
}];
let arguments = (ins
DenseTensor:$input_tensor,
DenseTensor:$kernel_weights,
DenseTensor:$bias_weights,
SI32Attr:$out_channel_num
);
let results = (outs
DenseTensor:$output_tensor
);
}
def TRT_ConvolutionOp : TRT_Op<"Convolution", [NoSideEffect]> {
let summary = "TensorRT IConvolutionLayer";
let description = [{
TensorRT IConvolutionLayer
}];
let arguments = (ins
DenseTensor:$input_tensor,
DenseTensor:$kernel_weights,
DenseTensor:$bias_weights,
SI32Attr:$out_channel_num,
I32ArrayAttr:$kernel_size
);
let results = (outs
DenseTensor:$output_tensor
);
}
def TRT_ElementWiseOp : TRT_Op<"ElementWise", [NoSideEffect]> {
  let summary = "TensorRT IElementWiseLayer";
  let description = [{
......
@@ -298,14 +298,21 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
   // add a naive implement.
   for (int i = 0, e = op->getNumOperands(); i < e; ++i) {
     auto operand = op->getOperand(i);
+    Value* arg_value{nullptr};
     if (operand.isa<mlir::BlockArgument>()) {
       mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>();
-      Value* arg_value = GetValue(arg);
-      if (arg_value->is_type<phi::DenseTensor>()) {
-        impl_->runtime->FeedInArgs(
-            std::make_pair(std::to_string(i), ValueRef(arg_value)));
-      }
+      arg_value = GetValue(arg);
+    } else {
+      arg_value = GetValue(operand);
+      if (!arg_value) {
+        auto upstream_op = operand.getDefiningOp();
+        arg_value = GetOpResult(upstream_op);
+      }
     }
+    if (arg_value->is_type<phi::DenseTensor>()) {
+      impl_->runtime->FeedInArgs(
+          std::make_pair(std::to_string(i), ValueRef(arg_value)));
+    }
   }
 #else
   CHECK(false) << "should not reach here";
......
@@ -146,8 +146,8 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
   // TensorList related methods.
 #ifdef INFRT_WITH_PHI
-  registry->AddKernel("dt.tensor_list_get_tensor",
-                      INFRT_KERNEL(TensorListGetTensor));
+  registry->AddKernelWithAttrs(
+      "dt.tensor_list_get_tensor", INFRT_KERNEL(TensorListGetTensor), {"id"});
   registry->AddKernel("dt.tensor_list_get_size",
                       INFRT_KERNEL(TensorListGetSize));
 #endif
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <NvInfer.h>
#include <NvInferRuntime.h>
#include <NvInferRuntimeCommon.h>
#include "glog/logging.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
namespace infrt {
namespace kernel {
namespace tensorrt {
static nvinfer1::DataType TensorTypeToWeightType(phi::DataType tensor_type) {
switch (tensor_type) {
case phi::DataType::FLOAT32:
return nvinfer1::DataType::kFLOAT;
case phi::DataType::INT32:
return nvinfer1::DataType::kINT32;
case phi::DataType::FLOAT16:
return nvinfer1::DataType::kHALF;
default:
llvm_unreachable("should not reach here");
}
}
static nvinfer1::Dims ArrayAttrToNvDims(const mlir::ArrayAttr& int_array_attr) {
nvinfer1::Dims dims;
dims.nbDims = int_array_attr.size();
CHECK(!int_array_attr.empty());
CHECK(int_array_attr[0].getType().isIntOrIndex());
for (int i = 0; i < dims.nbDims; ++i) {
dims.d[i] = int_array_attr[i].cast<mlir::IntegerAttr>().getInt();
}
return dims;
}
static nvinfer1::Weights TensorToWeights(phi::DenseTensor* tensor) {
CHECK_NOTNULL(tensor);
nvinfer1::Weights ret;
ret.type = TensorTypeToWeightType(tensor->dtype());
ret.count = tensor->numel();
ret.values = tensor->data();
return ret;
}
} // namespace tensorrt
} // namespace kernel
} // namespace infrt
@@ -21,13 +21,19 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
+
+#include "paddle/infrt/kernel/tensorrt/trt_helper.h"
+#include "paddle/infrt/kernel/tensorrt/trt_layers.h"
+
 #include "paddle/infrt/backends/tensorrt/trt_engine.h"
 #include "paddle/infrt/backends/tensorrt/trt_options.h"
 #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
 #include "paddle/infrt/host_context/symbol_table.h"
+#include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"

 namespace infrt {
@@ -35,8 +41,7 @@ namespace kernel {
 namespace tensorrt {

 ::infrt::backends::tensorrt::TrtEngine CreateTrtEngine(
-    MlirOperationWithInfrtSymbol
-    create_engine_op /*, input_tensors, output_tensors, weights*/) {
+    MlirOperationWithInfrtSymbol create_engine_op) {
   // TODO(wilber): The device_id needs to get from mlir.
   int device_id = 0;
   backends::tensorrt::TrtEngine engine(device_id);
@@ -51,6 +56,7 @@ namespace tensorrt {
   // TODO(wilber): The build option shoule be fiiled from mlir info.
   backends::tensorrt::BuildOptions options;
   options.max_batch = 4;
+  options.workspace = 1024;

   // Parse mlir Region which only has one block.
   mlir::Operation& operation = *create_engine_op.operation;
@@ -62,8 +68,9 @@ namespace tensorrt {
   auto& region = operation.getRegion(0);
   auto& block = region.getBlocks().front();

-  llvm::DenseMap<mlir::Value, nvinfer1::ITensor*> map_info;
   std::unordered_map<std::string, phi::DenseTensor*> trt_bind_inputs;
+  ValueToITensorMap value_to_trt_tensor_map;
+  ValueToTensorMap value_to_tensor_map;

   for (auto index_operand : llvm::enumerate(operation.getOperands())) {
     mlir::Value operand = index_operand.value();
@@ -73,69 +80,72 @@ namespace tensorrt {
     auto* v = symbol_table->GetValue(std::to_string(idx));
     CHECK_NOTNULL(v);
     auto* t = &v->get<phi::DenseTensor>();
-    trt_bind_inputs[input_name] = t;
+    value_to_tensor_map[operand] = t;

     // TODO(wilber): get input info from mlir.

     // TODO(wilber): input dims, now only support static_shape, and just remove
-    // the first dimension.
+    // the first dimension. If the first dim is not -1, maybe we can pass the
+    // origin dims.

     // TODO(wilber): now only suppot float input.
-    nvinfer1::Dims dims;
-    dims.nbDims = t->dims().size() - 1;
-    for (int i = 0; i < dims.nbDims; ++i) {
-      dims.d[i] = t->dims()[i + 1];
-    }
-    auto* in =
-        network->addInput(input_name.c_str(), nvinfer1::DataType::kFLOAT, dims);
-    map_info[operand] = in;
-  }
-
-  // TODO(wilber): Find a way to add layer.
-  for (auto& inner_op : block.without_terminator()) {
-    if (inner_op.getName().getStringRef() == "trt.Activation") {
-      trt::ActivationOp act_op = llvm::dyn_cast<trt::ActivationOp>(inner_op);
-      auto in_arg = act_op.getOperand();
-      if (!map_info.count(in_arg)) {
-        CHECK(false) << "map_info not has in_arg.";
-      }
-      nvinfer1::ActivationType act_type =
-          static_cast<nvinfer1::ActivationType>(act_op.activation_type());
-      auto* act_layer = network->addActivation(*map_info[in_arg], act_type);
-      act_layer->setAlpha(act_op.alpha().convertToFloat());
-      act_layer->setBeta(act_op.beta().convertToFloat());
-      for (size_t i = 0; i < act_op->getNumResults(); ++i) {
-        nvinfer1::ITensor* act_out_tensor = act_layer->getOutput(i);
-        mlir::Value act_out = act_op->getResult(i);
-        map_info[act_out] = act_out_tensor;
-      }
-    }
-    // if (inner_op.getName().getStringRef() == "trt.Constant") {
-    //   trt::ConstantOp op = llvm::dyn_cast<trt::ConstantOp>(inner_op);
-    //   mlir::Value op_out = op.getResult();
-    //   std::vector<float> weight_data{1};
-    //   auto* layer = network->addConstant(nvinfer1::Dims2(1, 1),
-    //   nvinfer1::Weights{nvinfer1::DataType::kFLOAT, weight_data.data(), 1});
-    //   auto* op_out_tenor = layer->getOutput(0);
-    //   map_info[op_out] = op_out_tenor;
-    // }
-  }
-  for (auto& inner_op : block.without_terminator()) {
-    for (mlir::Value v : inner_op.getResults()) {
-      for (mlir::Operation* user : v.getUsers()) {
-        if (user->getName().getStringRef() == "infrt.return") {
-          if (!map_info.count(v)) {
-            CHECK(false) << "map_info not has value";
-          }
-          network->markOutput(*map_info[v]);
-        }
-      }
-    }
-  }
-  // std::unordered_map<std::string, phi::DenseTensor*> trt_bind_outputs;
-  mlir::Operation* ret = block.getTerminator();
-  for (unsigned int i = 0; i < ret->getNumOperands(); ++i) {
-    mlir::Value arg = ret->getOperand(i);
-    CHECK(map_info.count(arg));
-    map_info[arg]->setName(("output_" + std::to_string(i)).c_str());
+    if (operand.isa<mlir::BlockArgument>()) {
+      // TODO(wilber): A trick: the weights are CPU tensor and inputs are GPU
+      // tensor, so we treat all GPU tensors as inputs to trt.
+      if (t->place().GetType() == phi::AllocationType::GPU) {
+        trt_bind_inputs[input_name] = t;
+        nvinfer1::Dims dims;
+        dims.nbDims = t->dims().size() - 1;
+        for (int i = 0; i < dims.nbDims; ++i) {
+          dims.d[i] = t->dims()[i + 1];
+        }
+        auto* in = network->addInput(
+            input_name.c_str(), nvinfer1::DataType::kFLOAT, dims);
+        value_to_trt_tensor_map[operand] = in;
+      }
+    } else {
+      // TODO(wilber): Replace with the op name that generates the weights.
+      if (operand.getDefiningOp()->getName().getStringRef() !=
+          "phi_dt.create_dense_tensor.cpu") {
+        trt_bind_inputs[input_name] = t;
+        nvinfer1::Dims dims;
+        dims.nbDims = t->dims().size() - 1;
+        for (int i = 0; i < dims.nbDims; ++i) {
+          dims.d[i] = t->dims()[i + 1];
+        }
+        auto* in = network->addInput(
+            input_name.c_str(), nvinfer1::DataType::kFLOAT, dims);
+        value_to_trt_tensor_map[operand] = in;
+      }
+    }
+  }
+
+  // TODO(wilber): Find a way to add layer.
+  for (auto& operation : block.without_terminator()) {
+    if (trt::ActivationOp op = llvm::dyn_cast<trt::ActivationOp>(operation)) {
+      ActivationFunc(
+          op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
+    } else if (trt::FullyConnectedOp op =
+                   llvm::dyn_cast<trt::FullyConnectedOp>(operation)) {
+      FcFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
+    } else if (trt::ConvolutionOp op =
+                   llvm::dyn_cast<trt::ConvolutionOp>(operation)) {
+      ConvFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
+    } else {
+      CHECK(false) << "not supported operation.";
+    }
+  }
+
+  for (auto index_operand :
+       llvm::enumerate(block.getTerminator()->getOperands())) {
+    mlir::Value arg = index_operand.value();
+    CHECK(value_to_trt_tensor_map.count(arg));
+    // TODO(wilber): A trick that we name trt output tensor's name as output_0,
+    // output_1, ...
+    value_to_trt_tensor_map[arg]->setName(
+        ("output_" + std::to_string(index_operand.index())).c_str());
+    network->markOutput(*value_to_trt_tensor_map[arg]);
   }
   for (int i = 0; i < network->getNbOutputs(); ++i) {
     engine.PrepareOutputHandle(network->getOutput(i)->getName());
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <NvInfer.h>
#include <mlir/IR/Operation.h>
#include <string>
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "paddle/infrt/kernel/tensorrt/trt_helper.h"
#include "paddle/phi/core/dense_tensor.h"
namespace infrt {
namespace kernel {
namespace tensorrt {
using ValueToTensorMap = llvm::DenseMap<mlir::Value, phi::DenseTensor*>;
using ValueToITensorMap = llvm::DenseMap<mlir::Value, nvinfer1::ITensor*>;
inline void ActivationFunc(
trt::ActivationOp& act_op, // NOLINT
nvinfer1::INetworkDefinition* network,
ValueToITensorMap& value_to_trt_tensor_map, // NOLINT
ValueToTensorMap& value_to_tensor_map) { // NOLINT
auto in_arg = act_op.getOperand();
CHECK(value_to_trt_tensor_map.count(in_arg))
<< "value_to_trt_tensor_map not has in_arg.";
nvinfer1::ActivationType act_type =
static_cast<nvinfer1::ActivationType>(act_op.activation_type());
auto* act_layer =
network->addActivation(*value_to_trt_tensor_map[in_arg], act_type);
act_layer->setAlpha(act_op.alpha().convertToFloat());
act_layer->setBeta(act_op.beta().convertToFloat());
for (size_t i = 0; i < act_op->getNumResults(); ++i) {
nvinfer1::ITensor* act_out_tensor = act_layer->getOutput(i);
mlir::Value act_out = act_op->getResult(i);
value_to_trt_tensor_map[act_out] = act_out_tensor;
}
}
inline void ConvFunc(trt::ConvolutionOp& op, // NOLINT
nvinfer1::INetworkDefinition* network,
ValueToITensorMap& value_to_trt_tensor_map, // NOLINT
ValueToTensorMap& value_to_tensor_map) { // NOLINT
mlir::Value input_tensor_repr = op.input_tensor();
int out_channel_num = op.out_channel_num();
auto size_attrs = op.kernel_size();
nvinfer1::Dims dims = ArrayAttrToNvDims(size_attrs);
auto kernel_weights =
TensorToWeights(value_to_tensor_map[op.kernel_weights()]);
auto bias_weights = TensorToWeights(value_to_tensor_map[op.bias_weights()]);
auto* layer =
network->addConvolutionNd(*value_to_trt_tensor_map[input_tensor_repr],
out_channel_num,
dims,
kernel_weights,
bias_weights);
CHECK_NOTNULL(layer);
mlir::Value out_repr = op.output_tensor();
nvinfer1::ITensor* out_tensor = layer->getOutput(0);
value_to_trt_tensor_map[out_repr] = out_tensor;
}
inline void FcFunc(trt::FullyConnectedOp& op, // NOLINT
nvinfer1::INetworkDefinition* network,
ValueToITensorMap& value_to_trt_tensor_map, // NOLINT
ValueToTensorMap& value_to_tensor_map) { // NOLINT
mlir::Value input_tensor_repr = op.input_tensor();
CHECK(value_to_trt_tensor_map.count(input_tensor_repr));
auto kernel_weights =
TensorToWeights(value_to_tensor_map[op.kernel_weights()]);
auto bias_weights = TensorToWeights(value_to_tensor_map[op.bias_weights()]);
int out_channel_num = op.out_channel_num();
auto* layer =
network->addFullyConnected(*value_to_trt_tensor_map[input_tensor_repr],
out_channel_num,
kernel_weights,
bias_weights);
mlir::Value out_repr = op.output_tensor();
nvinfer1::ITensor* out_tensor = layer->getOutput(0);
value_to_trt_tensor_map[out_repr] = out_tensor;
}
} // namespace tensorrt
} // namespace kernel
} // namespace infrt
// RUN: infrtexec -i %s | FileCheck %s
// CHECK-LABEL: @run_trt
func @run_trt(%input_tensor : !infrt.dense_tensor<GPU, FP32, NCHW>, %kernel_weight : !infrt.dense_tensor<CPU, FP32, NCHW>, %kernel_bias : !infrt.dense_tensor<CPU, FP32, NCHW>, %gpu_ctx : !phi.context<GPU>) {
%a = "trt.create_engine"(%input_tensor, %kernel_weight, %kernel_bias) ({
%1 = "trt.Activation"(%input_tensor) {activation_type = 1 : si32, alpha = 1.0 : f32, beta = 6.0 : f32} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%2 = "trt.Convolution"(%input_tensor, %kernel_weight, %kernel_bias) {out_channel_num = 3 : si32, kernel_size = [3:i32, 3:i32]} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
"infrt.return"(%1, %2) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
}) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !trt.engine
"trt.inspect_engine"(%a) {} : (!trt.engine) -> ()
%res = "trt.compute"(%a, %gpu_ctx) {} : (!trt.engine, !phi.context<GPU>) -> (!infrt.tensor_list)
%size = "dt.tensor_list_get_size"(%res) {} : (!infrt.tensor_list) -> (i32)
"infrt.print.i32"(%size) {} : (i32) -> ()
%ts0 = "dt.tensor_list_get_tensor"(%res) {id = 0 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
"phi_dt.print_tensor" (%ts0) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
%ts1 = "dt.tensor_list_get_tensor"(%res) {id = 1 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
"phi_dt.print_tensor" (%ts1) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
infrt.return
}
// CHECK-LABEL: @main
func @main() {
%gpu_ctx = "phi_dt.create_context.gpu" (): () -> !phi.context<GPU>
%cpu_ctx = "phi_dt.create_context.cpu" (): () -> !phi.context<CPU>
%input_tensor = "phi_dt.create_dense_tensor.gpu" (%gpu_ctx) {
precision=#infrt.precision<FP32>,
layout=#infrt.layout<NCHW>,
dims=[1:i64, 3:i64, 28:i64, 28:i64], lod=[0:i64]}: (!phi.context<GPU>) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%input_tensor) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
// "phi_dt.print_tensor" (%input_tensor) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
%kernel_weight = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) {
precision=#infrt.precision<FP32>,
layout=#infrt.layout<NCHW>,
dims=[3:i64, 3:i64, 3:i64, 3:i64], lod=[0:i64]} : (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%kernel_weight) {value=[1.:f32, 2.:f32, 3.:f32, 4.:f32, 5.:f32, 6.:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
// "phi_dt.print_tensor" (%kernel_weight) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%kernel_bias = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) {
precision=#infrt.precision<FP32>,
layout=#infrt.layout<NCHW>,
dims=[3:i64], lod=[0:i64]} : (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%kernel_bias) {value=[1.:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
// "phi_dt.print_tensor" (%kernel_bias) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
infrt.call @run_trt(%input_tensor, %kernel_weight, %kernel_bias, %gpu_ctx) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !phi.context<GPU>) -> ()
infrt.return
}
// RUN: infrtexec -i %s | FileCheck %s
// CHECK-LABEL: @main
func @main() {
%ctx = "phi_dt.create_context.gpu" (): () -> !phi.context<GPU>
%cpu_ctx = "phi_dt.create_context.cpu" (): () -> !phi.context<CPU>
%input_tensor = "phi_dt.create_dense_tensor.gpu" (%ctx) {
precision=#infrt.precision<FP32>,
layout=#infrt.layout<NCHW>,
dims=[1:i64, 3:i64, 1:i64, 1:i64], lod=[1:i64]}: (!phi.context<GPU>) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%input_tensor) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
//"phi_dt.print_tensor" (%input_tensor) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
%kernel_weight = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) {
precision=#infrt.precision<FP32>,
layout=#infrt.layout<NCHW>,
dims=[2:i64, 3:i64], lod=[1:i64]} : (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%kernel_weight) {value=[1.:f32, 2.:f32, 3.:f32, 4.:f32, 5.:f32, 6.:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
//"phi_dt.print_tensor" (%kernel_weight) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%kernel_bias = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) {
precision=#infrt.precision<FP32>,
layout=#infrt.layout<NCHW>,
dims=[2:i64], lod=[1:i64]} : (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%kernel_bias) {value=[1.:f32, 2.:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
//"phi_dt.print_tensor" (%kernel_bias) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%engine = "trt.create_engine"(%input_tensor, %kernel_weight, %kernel_bias) ({
%1 = "trt.Activation"(%input_tensor) {activation_type = 1 : si32, alpha = 1.0 : f32, beta = 6.0 : f32} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%2 = "trt.FullyConnected"(%input_tensor, %kernel_weight, %kernel_bias) {out_channel_num = 2 : si32} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
"infrt.return"(%1, %2) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
}) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !trt.engine
%res = "trt.compute"(%engine, %ctx) {} : (!trt.engine, !phi.context<GPU>) -> (!infrt.tensor_list)
%size = "dt.tensor_list_get_size"(%res) {} : (!infrt.tensor_list) -> (i32)
"infrt.print.i32"(%size) {} : (i32) -> ()
%ts0 = "dt.tensor_list_get_tensor"(%res) {id = 0 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
"phi_dt.print_tensor" (%ts0) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
%ts1 = "dt.tensor_list_get_tensor"(%res) {id = 1 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
"phi_dt.print_tensor" (%ts1) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
infrt.return
}