From 869199102e6b43aeb578f4d2dcbd21a0d53104cd Mon Sep 17 00:00:00 2001
From: Wilber
Date: Fri, 18 Mar 2022 14:34:12 +0800
Subject: [PATCH] Trt engine (#40649)

---
 .../backends/tensorrt/test_trt_engine.cc      | 167 ++++++++++++++++++
 .../infrt/dialect/infrt/ir/infrt_dialect.cc   |   4 +-
 paddle/infrt/dialect/tensorrt/trt_ops.td      |  33 ++++
 .../host_context/mlir_to_runtime_translate.cc |  15 +-
 paddle/infrt/kernel/tensor_kernels.cc         |   4 +-
 paddle/infrt/kernel/tensorrt/trt_helper.h     |  66 +++++++
 paddle/infrt/kernel/tensorrt/trt_kernels.cc   | 122 +++++++------
 paddle/infrt/kernel/tensorrt/trt_layers.h     | 104 +++++++++++
 .../dialect/{ => tensorrt}/disabled_trt.mlir  |   0
 .../dialect/tensorrt/disabled_trt_conv.mlir   |  54 ++++++
 .../dialect/tensorrt/disabled_trt_fc.mlir     |  46 +++++
 11 files changed, 550 insertions(+), 65 deletions(-)
 create mode 100644 paddle/infrt/kernel/tensorrt/trt_helper.h
 create mode 100644 paddle/infrt/kernel/tensorrt/trt_layers.h
 rename paddle/infrt/tests/dialect/{ => tensorrt}/disabled_trt.mlir (100%)
 create mode 100644 paddle/infrt/tests/dialect/tensorrt/disabled_trt_conv.mlir
 create mode 100644 paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir

diff --git a/paddle/infrt/backends/tensorrt/test_trt_engine.cc b/paddle/infrt/backends/tensorrt/test_trt_engine.cc
index 0ab64dd51c..89dd3b0dc7 100644
--- a/paddle/infrt/backends/tensorrt/test_trt_engine.cc
+++ b/paddle/infrt/backends/tensorrt/test_trt_engine.cc
@@ -82,9 +82,176 @@ TrtUniquePtr<nvinfer1::INetworkDefinition> ConstructNetwork(
   return network;
 }
 
+TrtUniquePtr<nvinfer1::INetworkDefinition> ConstructFCNetwork(
+    nvinfer1::IBuilder* builder, nvinfer1::Dims dims, bool is_static_shape) {
+  TrtUniquePtr<nvinfer1::INetworkDefinition> network;
+  if (is_static_shape) {
+    network.reset(builder->createNetworkV2(0U));
+  } else {
+    auto networkFlags =
+        1U << static_cast<uint32_t>(
+            nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+    network.reset(builder->createNetworkV2(networkFlags));
+  }
+
+  ITensor* data =
+      network->addInput(model_input, nvinfer1::DataType::kFLOAT, dims);
+  CHECK_NOTNULL(data);
+  nvinfer1::Weights kernel_weights;
+  kernel_weights.type = nvinfer1::DataType::kFLOAT;
+  kernel_weights.count = 7840;
+  std::vector<float> weight_data(kernel_weights.count);
+  for (size_t i = 0; i < weight_data.size(); ++i) {
+    weight_data[i] = i % 255 * 0.02f;
+  }
+  kernel_weights.values = weight_data.data();
+  auto* layer = network->addFullyConnected(
+      *data, 10, kernel_weights, nvinfer1::Weights{});
+  CHECK_NOTNULL(layer);
+  auto* out = layer->getOutput(0);
+  out->setName(model_output);
+  network->markOutput(*out);
+  return network;
+}
+
+TrtUniquePtr<nvinfer1::INetworkDefinition> ConstructConvNetwork(
+    nvinfer1::IBuilder* builder, nvinfer1::Dims dims, bool is_static_shape) {
+  TrtUniquePtr<nvinfer1::INetworkDefinition> network;
+  if (is_static_shape) {
+    network.reset(builder->createNetworkV2(0U));
+  } else {
+    auto networkFlags =
+        1U << static_cast<uint32_t>(
+            nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+    network.reset(builder->createNetworkV2(networkFlags));
+  }
+
+  ITensor* data =
+      network->addInput(model_input, nvinfer1::DataType::kFLOAT, dims);
+  CHECK_NOTNULL(data);
+  nvinfer1::Weights kernel_weights, bias_weights;
+  kernel_weights.type = nvinfer1::DataType::kFLOAT;
+  bias_weights.type = nvinfer1::DataType::kFLOAT;
+  kernel_weights.count = 81;
+  bias_weights.count = 3;
+  std::vector<float> weight_data(kernel_weights.count);
+  for (size_t i = 0; i < weight_data.size(); ++i) {
+    weight_data[i] = i * 0.02f;
+  }
+  std::vector<float> bias_data(bias_weights.count);
+  for (size_t i = 0; i < bias_data.size(); ++i) {
+    bias_data[i] = i * 0.5f;
+  }
+  kernel_weights.values = weight_data.data();
+  bias_weights.values = bias_data.data();
+  nvinfer1::Dims ksize;
+  ksize.nbDims = 2;
+  ksize.d[0] = 3;
+  ksize.d[1] = 3;
+  auto* layer =
+      network->addConvolutionNd(*data, 3, ksize, kernel_weights, bias_weights);
+  CHECK_NOTNULL(layer);
+  auto* out = layer->getOutput(0);
+  out->setName(model_output);
+  network->markOutput(*out);
+  return network;
+}
+
 // sigmoid(x) = 1 / (1 + exp(-x))
 inline float sigmoid(float x) { return 1.f / (1.f + exp(-1 * x)); }
 
+TEST(trt, run_fc_static) {
+  TrtEngine engine(0);
+  auto net = ConstructFCNetwork(
+      engine.GetTrtBuilder(), nvinfer1::Dims3{1, 28, 28}, true);
+  BuildOptions build_options;
+  build_options.max_batch = 4;
+  build_options.workspace = 1024;
+  engine.Build(std::move(net), build_options);
+
+  InferenceOptions inference_options;
+  inference_options.batch = 1;
+
+  phi::GPUPlace place;
+  phi::GPUContext context;
+  context.PartialInitWithoutAllocator();
+  context.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
+                           .GetAllocator(place, context.stream())
+                           .get());
+  context.PartialInitWithAllocator();
+
+  phi::DenseTensorMeta meta(
+      phi::DataType::FLOAT32,
+      phi::make_ddim({inference_options.batch, 1, 28, 28}));
+  phi::DenseTensor input;
+  input.set_meta(meta);
+  context.Alloc<float>(&input, input.numel() * sizeof(float));
+  std::vector<float> host_data(inference_options.batch * 1 * 28 * 28, 0);
+  for (size_t i = 0; i < host_data.size(); ++i) {
+    host_data[i] = i % 100 * 0.016f;
+  }
+  paddle::memory::Copy(place,
+                       input.data<float>(),
+                       phi::CPUPlace(),
+                       host_data.data(),
+                       sizeof(float) * host_data.size(),
+                       context.stream());
+
+  std::unordered_map<std::string, phi::DenseTensor*> inputs;
+  inputs.emplace(std::make_pair(model_input, &input));
+  engine.PrepareOutputHandle("output_0");
+  engine.SetUpInference(inference_options, inputs);
+  engine.GetEngineInfo();
+  engine.Run(context);
+  cudaStreamSynchronize(context.stream());
+}
+
+TEST(trt, run_conv_static) {
+  TrtEngine engine(0);
+  auto net = ConstructConvNetwork(
+      engine.GetTrtBuilder(), nvinfer1::Dims3{3, 28, 28}, true);
+  BuildOptions build_options;
+  build_options.max_batch = 4;
+  build_options.workspace = 1024;
+  engine.Build(std::move(net), build_options);
+
+  InferenceOptions inference_options;
+  inference_options.batch = 1;
+
+  phi::GPUPlace place;
+  phi::GPUContext context;
+  context.PartialInitWithoutAllocator();
+  context.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
+                           .GetAllocator(place, context.stream())
+                           .get());
+  context.PartialInitWithAllocator();
+
+  phi::DenseTensorMeta meta(
+      phi::DataType::FLOAT32,
+      phi::make_ddim({inference_options.batch, 3, 28, 28}));
+  phi::DenseTensor input;
+  input.set_meta(meta);
+  context.Alloc<float>(&input, input.numel() * sizeof(float));
+  std::vector<float> host_data(inference_options.batch * 3 * 28 * 28, 0);
+  for (size_t i = 0; i < host_data.size(); ++i) {
+    host_data[i] = i % 100 * 0.016f;
+  }
+  paddle::memory::Copy(place,
+                       input.data<float>(),
+                       phi::CPUPlace(),
+                       host_data.data(),
+                       sizeof(float) * host_data.size(),
+                       context.stream());
+
+  std::unordered_map<std::string, phi::DenseTensor*> inputs;
+  inputs.emplace(std::make_pair(model_input, &input));
+  engine.PrepareOutputHandle("output_0");
+  engine.SetUpInference(inference_options, inputs);
+  engine.GetEngineInfo();
+  engine.Run(context);
+  cudaStreamSynchronize(context.stream());
+}
+
 TEST(trt, run_static) {
   TrtEngine static_trt_engine(0);
   auto net = ConstructNetwork(
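For context on the weight handling in `ConstructFCNetwork`/`ConstructConvNetwork` above: `nvinfer1::Weights` holds only a raw pointer, and TensorRT requires that memory to stay valid until the engine is built, while the tests back the weights with function-local `std::vector`s. Pairing the storage with the view avoids that pitfall; a minimal sketch (the `OwnedWeights` helper is hypothetical, not part of this patch):

```cpp
#include <NvInfer.h>

#include <cstdint>
#include <vector>

// Owns the weight storage and exposes a non-owning nvinfer1::Weights view.
// The storage must outlive engine build, since TensorRT reads it at that time.
struct OwnedWeights {
  std::vector<float> storage;   // owns the data
  nvinfer1::Weights weights{};  // view handed to addFullyConnected etc.

  explicit OwnedWeights(size_t count) : storage(count) {
    weights.type = nvinfer1::DataType::kFLOAT;
    weights.count = static_cast<int64_t>(count);
    weights.values = storage.data();
  }
};
```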
diff --git a/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc b/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc
index 8966ca13c2..f8d8f51474 100644
--- a/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc
+++ b/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc
@@ -142,9 +142,6 @@ mlir::Type InfrtDialect::parseType(::mlir::DialectAsmParser &parser) const {
     return infrt::DenseTensorListType::get(parser.getContext());
   }
 
-  if (keyword == "dense_tensor_map") {
-    return DenseTensorMapType::get(parser.getContext());
-  }
   // Todo: parse other type
   return mlir::Type();
 }
@@ -181,6 +178,7 @@ void InfrtDialect::printType(::mlir::Type type,
   if (type.isa<infrt::DenseTensorListType>()) {
     os << "tensor_list";
+    return;
   }
   // print DenseTensorType, for example: !infrt.dense_tensor<CPU, FP32, NCHW>
   if (type.isa<DenseTensorType>()) {
diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td
index 31b28a38e7..803a11ed5b 100755
--- a/paddle/infrt/dialect/tensorrt/trt_ops.td
+++ b/paddle/infrt/dialect/tensorrt/trt_ops.td
@@ -60,6 +60,39 @@ def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> {
   let results = (outs DenseTensor:$output);
 }
 
+def TRT_FullyConnectedOp : TRT_Op<"FullyConnected", [NoSideEffect]> {
+  let summary = "TensorRT IFullyConnectedLayer";
+  let description = [{
+    TensorRT IFullyConnectedLayer
+  }];
+  let arguments = (ins
+    DenseTensor:$input_tensor,
+    DenseTensor:$kernel_weights,
+    DenseTensor:$bias_weights,
+    SI32Attr:$out_channel_num
+  );
+  let results = (outs
+    DenseTensor:$output_tensor
+  );
+}
+
+def TRT_ConvolutionOp : TRT_Op<"Convolution", [NoSideEffect]> {
+  let summary = "TensorRT IConvolutionLayer";
+  let description = [{
+    TensorRT IConvolutionLayer
+  }];
+  let arguments = (ins
+    DenseTensor:$input_tensor,
+    DenseTensor:$kernel_weights,
+    DenseTensor:$bias_weights,
+    SI32Attr:$out_channel_num,
+    I32ArrayAttr:$kernel_size
+  );
+  let results = (outs
+    DenseTensor:$output_tensor
+  );
+}
+
 def TRT_ElementWiseOp : TRT_Op<"ElementWise", [NoSideEffect]> {
   let summary = "TensorRT IElementWiseLayer";
   let description = [{
diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate.cc b/paddle/infrt/host_context/mlir_to_runtime_translate.cc
index 3d5cccb5c3..bcd44540b3 100644
--- a/paddle/infrt/host_context/mlir_to_runtime_translate.cc
+++ b/paddle/infrt/host_context/mlir_to_runtime_translate.cc
@@ -298,14 +298,21 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
   // add a naive implementation.
   for (int i = 0, e = op->getNumOperands(); i < e; ++i) {
     auto operand = op->getOperand(i);
+    Value* arg_value{nullptr};
     if (operand.isa<mlir::BlockArgument>()) {
       mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>();
-      Value* arg_value = GetValue(arg);
-      if (arg_value->is_type<phi::DenseTensor>()) {
-        impl_->runtime->FeedInArgs(
-            std::make_pair(std::to_string(i), ValueRef(arg_value)));
+      arg_value = GetValue(arg);
+    } else {
+      arg_value = GetValue(operand);
+      if (!arg_value) {
+        auto upstream_op = operand.getDefiningOp();
+        arg_value = GetOpResult(upstream_op);
       }
     }
+    if (arg_value->is_type<phi::DenseTensor>()) {
+      impl_->runtime->FeedInArgs(
+          std::make_pair(std::to_string(i), ValueRef(arg_value)));
+    }
   }
 #else
   CHECK(false) << "should not reach here";
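The `mlir_to_runtime_translate.cc` change above generalizes argument feeding: an operand may now be an SSA value produced by an upstream op, not just a function (block) argument. A hedged sketch of the resolution order, with `GetValue`/`GetOpResult` as declared stand-ins for the translator's internal lookups:

```cpp
#include "mlir/IR/Value.h"

namespace sketch {
class Value;                              // runtime value (stand-in)
Value* GetValue(mlir::Value v);           // stub: value-table lookup
Value* GetOpResult(mlir::Operation* op);  // stub: defining-op result lookup

inline Value* ResolveOperand(mlir::Value operand) {
  if (auto arg = operand.dyn_cast<mlir::BlockArgument>())
    return GetValue(arg);                       // function argument
  if (Value* v = GetValue(operand)) return v;   // already materialized
  return GetOpResult(operand.getDefiningOp());  // fall back to the producer
}
}  // namespace sketch
```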
diff --git a/paddle/infrt/kernel/tensor_kernels.cc b/paddle/infrt/kernel/tensor_kernels.cc
index 79502f9fdf..a9077220cf 100644
--- a/paddle/infrt/kernel/tensor_kernels.cc
+++ b/paddle/infrt/kernel/tensor_kernels.cc
@@ -146,8 +146,8 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
 
   // TensorList related methods.
 #ifdef INFRT_WITH_PHI
-  registry->AddKernel("dt.tensor_list_get_tensor",
-                      INFRT_KERNEL(TensorListGetTensor));
+  registry->AddKernelWithAttrs(
+      "dt.tensor_list_get_tensor", INFRT_KERNEL(TensorListGetTensor), {"id"});
   registry->AddKernel("dt.tensor_list_get_size",
                       INFRT_KERNEL(TensorListGetSize));
 #endif
diff --git a/paddle/infrt/kernel/tensorrt/trt_helper.h b/paddle/infrt/kernel/tensorrt/trt_helper.h
new file mode 100644
index 0000000000..96122bffac
--- /dev/null
+++ b/paddle/infrt/kernel/tensorrt/trt_helper.h
@@ -0,0 +1,66 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <NvInfer.h>
+#include <string>
+#include <vector>
+
+#include "glog/logging.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace infrt {
+namespace kernel {
+namespace tensorrt {
+
+static nvinfer1::DataType TensorTypeToWeightType(phi::DataType tensor_type) {
+  switch (tensor_type) {
+    case phi::DataType::FLOAT32:
+      return nvinfer1::DataType::kFLOAT;
+    case phi::DataType::INT32:
+      return nvinfer1::DataType::kINT32;
+    case phi::DataType::FLOAT16:
+      return nvinfer1::DataType::kHALF;
+    default:
+      llvm_unreachable("should not reach here");
+  }
+}
+
+static nvinfer1::Dims ArrayAttrToNvDims(const mlir::ArrayAttr& int_array_attr) {
+  nvinfer1::Dims dims;
+  dims.nbDims = int_array_attr.size();
+  CHECK(!int_array_attr.empty());
+  CHECK(int_array_attr[0].getType().isIntOrIndex());
+  for (int i = 0; i < dims.nbDims; ++i) {
+    dims.d[i] = int_array_attr[i].cast<mlir::IntegerAttr>().getInt();
+  }
+  return dims;
+}
+
+static nvinfer1::Weights TensorToWeights(phi::DenseTensor* tensor) {
+  CHECK_NOTNULL(tensor);
+  nvinfer1::Weights ret;
+  ret.type = TensorTypeToWeightType(tensor->dtype());
+  ret.count = tensor->numel();
+  ret.values = tensor->data();
+  return ret;
+}
+
+}  // namespace tensorrt
+}  // namespace kernel
+}  // namespace infrt
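Usage-wise, the helpers above compose like this; a minimal sketch (the `AsTrtWeights` wrapper is hypothetical, and `weights_tensor` is assumed to be a live CPU tensor, e.g. one created by `phi_dt.create_dense_tensor.cpu`):

```cpp
#include <NvInfer.h>

#include "paddle/infrt/kernel/tensorrt/trt_helper.h"
#include "paddle/phi/core/dense_tensor.h"

nvinfer1::Weights AsTrtWeights(phi::DenseTensor* weights_tensor) {
  // TensorToWeights maps dtype -> nvinfer1::DataType, numel -> count, and
  // aliases the tensor allocation: the tensor must outlive the engine build.
  return infrt::kernel::tensorrt::TensorToWeights(weights_tensor);
}
```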
diff --git a/paddle/infrt/kernel/tensorrt/trt_kernels.cc b/paddle/infrt/kernel/tensorrt/trt_kernels.cc
index 04847ac898..aa7609092b 100644
--- a/paddle/infrt/kernel/tensorrt/trt_kernels.cc
+++ b/paddle/infrt/kernel/tensorrt/trt_kernels.cc
@@ -21,13 +21,19 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
+
+#include "paddle/infrt/kernel/tensorrt/trt_helper.h"
+#include "paddle/infrt/kernel/tensorrt/trt_layers.h"
+
 #include "paddle/infrt/backends/tensorrt/trt_engine.h"
 #include "paddle/infrt/backends/tensorrt/trt_options.h"
 #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
 #include "paddle/infrt/host_context/symbol_table.h"
+#include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"
 
 namespace infrt {
@@ -35,8 +41,7 @@ namespace kernel {
 namespace tensorrt {
 
 ::infrt::backends::tensorrt::TrtEngine CreateTrtEngine(
-    MlirOperationWithInfrtSymbol
-        create_engine_op /*, input_tensors, output_tensors, weights*/) {
+    MlirOperationWithInfrtSymbol create_engine_op) {
   // TODO(wilber): The device_id needs to get from mlir.
   int device_id = 0;
   backends::tensorrt::TrtEngine engine(device_id);
@@ -51,6 +56,7 @@ namespace tensorrt {
   // TODO(wilber): The build option should be filled from mlir info.
   backends::tensorrt::BuildOptions options;
   options.max_batch = 4;
+  options.workspace = 1024;
 
   // Parse mlir Region which only has one block.
   mlir::Operation& operation = *create_engine_op.operation;
@@ -62,8 +68,9 @@ namespace tensorrt {
   auto& region = operation.getRegion(0);
   auto& block = region.getBlocks().front();
 
-  llvm::DenseMap<mlir::Value, nvinfer1::ITensor*> map_info;
   std::unordered_map<std::string, phi::DenseTensor*> trt_bind_inputs;
+  ValueToITensorMap value_to_trt_tensor_map;
+  ValueToTensorMap value_to_tensor_map;
 
   for (auto index_operand : llvm::enumerate(operation.getOperands())) {
     mlir::Value operand = index_operand.value();
@@ -73,69 +80,72 @@ namespace tensorrt {
     auto* v = symbol_table->GetValue(std::to_string(idx));
     CHECK_NOTNULL(v);
     auto* t = &v->get<phi::DenseTensor>();
-    trt_bind_inputs[input_name] = t;
+    value_to_tensor_map[operand] = t;
+    // TODO(wilber): get input info from mlir.
+
+    // TODO(wilber): input dims, now only support static_shape, and just remove
-    // the first dimension.
+    // the first dimension. If the first dim is not -1, maybe we can pass the
+    // origin dims.
+
+    // TODO(wilber): now only support float input.
-    nvinfer1::Dims dims;
-    dims.nbDims = t->dims().size() - 1;
-    for (int i = 0; i < dims.nbDims; ++i) {
-      dims.d[i] = t->dims()[i + 1];
-    }
-    auto* in =
-        network->addInput(input_name.c_str(), nvinfer1::DataType::kFLOAT, dims);
-    map_info[operand] = in;
-  }
-  // TODO(wilber): Find a way to add layer.
-  for (auto& inner_op : block.without_terminator()) {
-    if (inner_op.getName().getStringRef() == "trt.Activation") {
-      trt::ActivationOp act_op = llvm::dyn_cast<trt::ActivationOp>(inner_op);
-      auto in_arg = act_op.getOperand();
-      if (!map_info.count(in_arg)) {
-        CHECK(false) << "map_info does not contain in_arg.";
+    if (operand.isa<mlir::BlockArgument>()) {
+      // TODO(wilber): A trick: the weights are CPU tensors and inputs are GPU
+      // tensors, so we treat all GPU tensors as inputs to trt.
+      if (t->place().GetType() == phi::AllocationType::GPU) {
+        trt_bind_inputs[input_name] = t;
+        nvinfer1::Dims dims;
+        dims.nbDims = t->dims().size() - 1;
+        for (int i = 0; i < dims.nbDims; ++i) {
+          dims.d[i] = t->dims()[i + 1];
+        }
+        auto* in = network->addInput(
+            input_name.c_str(), nvinfer1::DataType::kFLOAT, dims);
+        value_to_trt_tensor_map[operand] = in;
       }
-      nvinfer1::ActivationType act_type =
-          static_cast<nvinfer1::ActivationType>(act_op.activation_type());
-      auto* act_layer = network->addActivation(*map_info[in_arg], act_type);
-      act_layer->setAlpha(act_op.alpha().convertToFloat());
-      act_layer->setBeta(act_op.beta().convertToFloat());
-      for (size_t i = 0; i < act_op->getNumResults(); ++i) {
-        nvinfer1::ITensor* act_out_tensor = act_layer->getOutput(i);
-        mlir::Value act_out = act_op->getResult(i);
-        map_info[act_out] = act_out_tensor;
+    } else {
+      // TODO(wilber): Replace with the op name that generates the weights.
+      if (operand.getDefiningOp()->getName().getStringRef() !=
+          "phi_dt.create_dense_tensor.cpu") {
+        trt_bind_inputs[input_name] = t;
+        nvinfer1::Dims dims;
+        dims.nbDims = t->dims().size() - 1;
+        for (int i = 0; i < dims.nbDims; ++i) {
+          dims.d[i] = t->dims()[i + 1];
+        }
+        auto* in = network->addInput(
+            input_name.c_str(), nvinfer1::DataType::kFLOAT, dims);
+        value_to_trt_tensor_map[operand] = in;
       }
     }
-
-    // if (inner_op.getName().getStringRef() == "trt.Constant") {
-    //   trt::ConstantOp op = llvm::dyn_cast<trt::ConstantOp>(inner_op);
-    //   mlir::Value op_out = op.getResult();
-    //   std::vector<float> weight_data{1};
-    //   auto* layer = network->addConstant(nvinfer1::Dims2(1, 1),
-    //   nvinfer1::Weights{nvinfer1::DataType::kFLOAT, weight_data.data(), 1});
-    //   auto* op_out_tenor = layer->getOutput(0);
-    //   map_info[op_out] = op_out_tenor;
-    // }
   }
-  for (auto& inner_op : block.without_terminator()) {
-    for (mlir::Value v : inner_op.getResults()) {
-      for (mlir::Operation* user : v.getUsers()) {
-        if (user->getName().getStringRef() == "infrt.return") {
-          if (!map_info.count(v)) {
-            CHECK(false) << "map_info does not contain value";
-          }
-          network->markOutput(*map_info[v]);
-        }
-      }
+
+  // TODO(wilber): Find a way to add layer.
+  for (auto& operation : block.without_terminator()) {
+    if (trt::ActivationOp op = llvm::dyn_cast<trt::ActivationOp>(operation)) {
+      ActivationFunc(
+          op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
+    } else if (trt::FullyConnectedOp op =
+                   llvm::dyn_cast<trt::FullyConnectedOp>(operation)) {
+      FcFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
+    } else if (trt::ConvolutionOp op =
+                   llvm::dyn_cast<trt::ConvolutionOp>(operation)) {
+      ConvFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
+    } else {
+      CHECK(false) << "not supported operation.";
     }
   }
-  // std::unordered_map<std::string, phi::DenseTensor*> trt_bind_outputs;
-  mlir::Operation* ret = block.getTerminator();
-  for (unsigned int i = 0; i < ret->getNumOperands(); ++i) {
-    mlir::Value arg = ret->getOperand(i);
-    CHECK(map_info.count(arg));
-    map_info[arg]->setName(("output_" + std::to_string(i)).c_str());
+
+  for (auto index_operand :
+       llvm::enumerate(block.getTerminator()->getOperands())) {
+    mlir::Value arg = index_operand.value();
+    CHECK(value_to_trt_tensor_map.count(arg));
+    // TODO(wilber): A trick that we name trt output tensor's name as output_0,
+    // output_1, ...
+    value_to_trt_tensor_map[arg]->setName(
+        ("output_" + std::to_string(index_operand.index())).c_str());
+    network->markOutput(*value_to_trt_tensor_map[arg]);
   }
   for (int i = 0; i < network->getNbOutputs(); ++i) {
     engine.PrepareOutputHandle(network->getOutput(i)->getName());
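The input-registration loop above drops the leading dimension of each GPU tensor before declaring it to TensorRT, treating it as the implicit batch. A hedged sketch of that shape mapping as a standalone helper (`DropBatchDim` is a hypothetical name):

```cpp
#include <NvInfer.h>

#include "paddle/phi/core/ddim.h"

// Static shapes only: dim 0 is assumed to be the batch and is stripped.
nvinfer1::Dims DropBatchDim(const phi::DDim& ddim) {
  nvinfer1::Dims dims;
  dims.nbDims = static_cast<int>(ddim.size()) - 1;
  for (int i = 0; i < dims.nbDims; ++i) {
    dims.d[i] = ddim[i + 1];  // copy everything after the batch dim
  }
  return dims;
}
```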
diff --git a/paddle/infrt/kernel/tensorrt/trt_layers.h b/paddle/infrt/kernel/tensorrt/trt_layers.h
new file mode 100644
index 0000000000..19e20c170e
--- /dev/null
+++ b/paddle/infrt/kernel/tensorrt/trt_layers.h
@@ -0,0 +1,104 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
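`trt_layers.h` (below) establishes the pattern for lowering one dialect op to one TensorRT layer. On the TensorRT side, a future layer function (say, pooling) would reduce to a call like the sketch that follows; the mlir plumbing (a `trt.Pooling` op and its dispatch branch in `CreateTrtEngine`) is not part of this patch:

```cpp
#include <NvInfer.h>

#include "glog/logging.h"

// Hypothetical helper: append a 2x2 max-pool layer and return its output,
// mirroring how ConvFunc/FcFunc record layer outputs for downstream ops.
nvinfer1::ITensor* AddMaxPool2x2(nvinfer1::INetworkDefinition* network,
                                 nvinfer1::ITensor* input) {
  nvinfer1::Dims window;
  window.nbDims = 2;
  window.d[0] = 2;
  window.d[1] = 2;
  auto* layer =
      network->addPoolingNd(*input, nvinfer1::PoolingType::kMAX, window);
  CHECK_NOTNULL(layer);
  return layer->getOutput(0);
}
```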
+
+#pragma once
+
+#include <NvInfer.h>
+#include <llvm/ADT/DenseMap.h>
+
+#include <string>
+
+#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
+#include "paddle/infrt/kernel/tensorrt/trt_helper.h"
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace infrt {
+namespace kernel {
+namespace tensorrt {
+
+using ValueToTensorMap = llvm::DenseMap<mlir::Value, phi::DenseTensor*>;
+using ValueToITensorMap = llvm::DenseMap<mlir::Value, nvinfer1::ITensor*>;
+
+inline void ActivationFunc(
+    trt::ActivationOp& act_op,  // NOLINT
+    nvinfer1::INetworkDefinition* network,
+    ValueToITensorMap& value_to_trt_tensor_map,  // NOLINT
+    ValueToTensorMap& value_to_tensor_map) {  // NOLINT
+  auto in_arg = act_op.getOperand();
+  CHECK(value_to_trt_tensor_map.count(in_arg))
+      << "value_to_trt_tensor_map does not contain in_arg.";
+
+  nvinfer1::ActivationType act_type =
+      static_cast<nvinfer1::ActivationType>(act_op.activation_type());
+  auto* act_layer =
+      network->addActivation(*value_to_trt_tensor_map[in_arg], act_type);
+  act_layer->setAlpha(act_op.alpha().convertToFloat());
+  act_layer->setBeta(act_op.beta().convertToFloat());
+  for (size_t i = 0; i < act_op->getNumResults(); ++i) {
+    nvinfer1::ITensor* act_out_tensor = act_layer->getOutput(i);
+    mlir::Value act_out = act_op->getResult(i);
+    value_to_trt_tensor_map[act_out] = act_out_tensor;
+  }
+}
+
+inline void ConvFunc(trt::ConvolutionOp& op,  // NOLINT
+                     nvinfer1::INetworkDefinition* network,
+                     ValueToITensorMap& value_to_trt_tensor_map,  // NOLINT
+                     ValueToTensorMap& value_to_tensor_map) {  // NOLINT
+  mlir::Value input_tensor_repr = op.input_tensor();
+  int out_channel_num = op.out_channel_num();
+  auto size_attrs = op.kernel_size();
+  nvinfer1::Dims dims = ArrayAttrToNvDims(size_attrs);
+  auto kernel_weights =
+      TensorToWeights(value_to_tensor_map[op.kernel_weights()]);
+  auto bias_weights = TensorToWeights(value_to_tensor_map[op.bias_weights()]);
+
+  auto* layer =
+      network->addConvolutionNd(*value_to_trt_tensor_map[input_tensor_repr],
+                                out_channel_num,
+                                dims,
+                                kernel_weights,
+                                bias_weights);
+  CHECK_NOTNULL(layer);
+  mlir::Value out_repr = op.output_tensor();
+  nvinfer1::ITensor* out_tensor = layer->getOutput(0);
+  value_to_trt_tensor_map[out_repr] = out_tensor;
+}
+
+inline void FcFunc(trt::FullyConnectedOp& op,  // NOLINT
+                   nvinfer1::INetworkDefinition* network,
+                   ValueToITensorMap& value_to_trt_tensor_map,  // NOLINT
+                   ValueToTensorMap& value_to_tensor_map) {  // NOLINT
+  mlir::Value input_tensor_repr = op.input_tensor();
+  CHECK(value_to_trt_tensor_map.count(input_tensor_repr));
+
+  auto kernel_weights =
+      TensorToWeights(value_to_tensor_map[op.kernel_weights()]);
+  auto bias_weights = TensorToWeights(value_to_tensor_map[op.bias_weights()]);
+
+  int out_channel_num = op.out_channel_num();
+  auto* layer =
+      network->addFullyConnected(*value_to_trt_tensor_map[input_tensor_repr],
+                                 out_channel_num,
+                                 kernel_weights,
+                                 bias_weights);
+
+  mlir::Value out_repr = op.output_tensor();
+  nvinfer1::ITensor* out_tensor = layer->getOutput(0);
+  value_to_trt_tensor_map[out_repr] = out_tensor;
+}
+}  // namespace tensorrt
+}  // namespace kernel
+}  // namespace infrt
diff --git a/paddle/infrt/tests/dialect/disabled_trt.mlir b/paddle/infrt/tests/dialect/tensorrt/disabled_trt.mlir
similarity index 100%
rename from paddle/infrt/tests/dialect/disabled_trt.mlir
rename to paddle/infrt/tests/dialect/tensorrt/disabled_trt.mlir
diff --git a/paddle/infrt/tests/dialect/tensorrt/disabled_trt_conv.mlir b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_conv.mlir
new file mode 100644
index 0000000000..c67d47415b
--- /dev/null
+++ b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_conv.mlir
@@ -0,0 +1,54 @@
+// RUN: infrtexec -i %s | FileCheck %s
+
+// CHECK-LABEL: @run_trt
+func @run_trt(%input_tensor : !infrt.dense_tensor<GPU, FP32, NCHW>, %kernel_weight : !infrt.dense_tensor<CPU, FP32, NCHW>, %kernel_bias : !infrt.dense_tensor<CPU, FP32, NCHW>, %gpu_ctx : !phi.context<GPU>) {
+  %a = "trt.create_engine"(%input_tensor, %kernel_weight, %kernel_bias) ({
+    %1 = "trt.Activation"(%input_tensor) {activation_type = 1 : si32, alpha = 1.0 : f32, beta = 6.0 : f32} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
+    %2 = "trt.Convolution"(%input_tensor, %kernel_weight, %kernel_bias) {out_channel_num = 3 : si32, kernel_size = [3:i32, 3:i32]} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
+    "infrt.return"(%1, %2) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
+  }) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !trt.engine
+  "trt.inspect_engine"(%a) {} : (!trt.engine) -> ()
+
+  %res = "trt.compute"(%a, %gpu_ctx) {} : (!trt.engine, !phi.context<GPU>) -> (!infrt.tensor_list)
+  %size = "dt.tensor_list_get_size"(%res) {} : (!infrt.tensor_list) -> (i32)
+  "infrt.print.i32"(%size) {} : (i32) -> ()
+
+  %ts0 = "dt.tensor_list_get_tensor"(%res) {id = 0 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
+  "phi_dt.print_tensor" (%ts0) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
+
+  %ts1 = "dt.tensor_list_get_tensor"(%res) {id = 1 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
+  "phi_dt.print_tensor" (%ts1) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
+
+  infrt.return
+}
+
+// CHECK-LABEL: @main
+func @main() {
+  %gpu_ctx = "phi_dt.create_context.gpu" (): () -> !phi.context<GPU>
+  %cpu_ctx = "phi_dt.create_context.cpu" (): () -> !phi.context<CPU>
+
+  %input_tensor = "phi_dt.create_dense_tensor.gpu" (%gpu_ctx) {
+    precision=#infrt.precision<FP32>,
+    layout=#infrt.layout<NCHW>,
+    dims=[1:i64, 3:i64, 28:i64, 28:i64], lod=[0:i64]}: (!phi.context<GPU>) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
+  "phi_dt.fill_dense_tensor.f32"(%input_tensor) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
+  // "phi_dt.print_tensor" (%input_tensor) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
+
+  %kernel_weight = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) {
+    precision=#infrt.precision<FP32>,
+    layout=#infrt.layout<NCHW>,
+    dims=[3:i64, 3:i64, 3:i64, 3:i64], lod=[0:i64]} : (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
+  "phi_dt.fill_dense_tensor.f32"(%kernel_weight) {value=[1.:f32, 2.:f32, 3.:f32, 4.:f32, 5.:f32, 6.:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
+  // "phi_dt.print_tensor" (%kernel_weight) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
+
+  %kernel_bias = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) {
+    precision=#infrt.precision<FP32>,
+    layout=#infrt.layout<NCHW>,
+    dims=[3:i64], lod=[0:i64]} : (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
+  "phi_dt.fill_dense_tensor.f32"(%kernel_bias) {value=[1.:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
+  // "phi_dt.print_tensor" (%kernel_bias) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
+
+  infrt.call @run_trt(%input_tensor, %kernel_weight, %kernel_bias, %gpu_ctx) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !phi.context<GPU>) -> ()
+
+  infrt.return
+}
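As a sanity check for the conv test above: with a 1x3x28x28 input, a 3x3 kernel, 3 output channels, and TensorRT's defaults (stride 1, no padding), the expected output is 1x3x26x26. A self-contained check of the arithmetic:

```cpp
#include <cstdio>

int main() {
  const int in = 28, k = 3, stride = 1, pad = 0;
  const int out = (in + 2 * pad - k) / stride + 1;  // standard conv formula
  std::printf("expected output spatial size: %dx%d\n", out, out);  // 26x26
  return 0;
}
```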
diff --git a/paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir
new file mode 100644
index 0000000000..78dc4ac1c1
--- /dev/null
+++ b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir
@@ -0,0 +1,46 @@
+// RUN: infrtexec -i %s | FileCheck %s
+
+// CHECK-LABEL: @main
+func @main() {
+  %ctx = "phi_dt.create_context.gpu" (): () -> !phi.context<GPU>
+  %cpu_ctx = "phi_dt.create_context.cpu" (): () -> !phi.context<CPU>
+
+  %input_tensor = "phi_dt.create_dense_tensor.gpu" (%ctx) {
+    precision=#infrt.precision<FP32>,
+    layout=#infrt.layout<NCHW>,
+    dims=[1:i64, 3:i64, 1:i64, 1:i64], lod=[1:i64]}: (!phi.context<GPU>) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
+  "phi_dt.fill_dense_tensor.f32"(%input_tensor) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
+  //"phi_dt.print_tensor" (%input_tensor) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
+
+  %kernel_weight = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) {
+    precision=#infrt.precision<FP32>,
+    layout=#infrt.layout<NCHW>,
+    dims=[2:i64, 3:i64], lod=[1:i64]} : (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
+  "phi_dt.fill_dense_tensor.f32"(%kernel_weight) {value=[1.:f32, 2.:f32, 3.:f32, 4.:f32, 5.:f32, 6.:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
+  //"phi_dt.print_tensor" (%kernel_weight) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
+
+  %kernel_bias = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) {
+    precision=#infrt.precision<FP32>,
+    layout=#infrt.layout<NCHW>,
+    dims=[2:i64], lod=[1:i64]} : (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
+  "phi_dt.fill_dense_tensor.f32"(%kernel_bias) {value=[1.:f32, 2.:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
+  //"phi_dt.print_tensor" (%kernel_bias) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
+
+  %engine = "trt.create_engine"(%input_tensor, %kernel_weight, %kernel_bias) ({
+    %1 = "trt.Activation"(%input_tensor) {activation_type = 1 : si32, alpha = 1.0 : f32, beta = 6.0 : f32} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
+    %2 = "trt.FullyConnected"(%input_tensor, %kernel_weight, %kernel_bias) {out_channel_num = 2 : si32} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
+    "infrt.return"(%1, %2) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
+  }) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !trt.engine
+
+  %res = "trt.compute"(%engine, %ctx) {} : (!trt.engine, !phi.context<GPU>) -> (!infrt.tensor_list)
+  %size = "dt.tensor_list_get_size"(%res) {} : (!infrt.tensor_list) -> (i32)
+  "infrt.print.i32"(%size) {} : (i32) -> ()
+
+  %ts0 = "dt.tensor_list_get_tensor"(%res) {id = 0 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
+  "phi_dt.print_tensor" (%ts0) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
+
+  %ts1 = "dt.tensor_list_get_tensor"(%res) {id = 1 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
+  "phi_dt.print_tensor" (%ts1) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
+
+  infrt.return
+}
-- 
GitLab
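A cross-check for the FC test above, assuming `IFullyConnectedLayer` applies the weights as an out_channels x in_channels matrix (row-major) plus bias, which matches the TensorRT documentation: with input (3.8, 2.4, 1.3), weights 1..6, and bias (1, 2), the engine should produce 13.5 and 37.0 for the FC output:

```cpp
#include <cstdio>

int main() {
  const float x[3] = {3.8f, 2.4f, 1.3f};
  const float w[2][3] = {{1, 2, 3}, {4, 5, 6}};
  const float b[2] = {1, 2};
  for (int o = 0; o < 2; ++o) {
    float acc = b[o];
    for (int i = 0; i < 3; ++i) acc += w[o][i] * x[i];
    std::printf("output_%d = %f\n", o, acc);  // 13.500000, 37.000000
  }
  return 0;
}
```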