From 0701160a5b9b1b18a0337b4e2f75c2cbab1f9899 Mon Sep 17 00:00:00 2001 From: Wilber Date: Thu, 7 Apr 2022 10:54:36 +0800 Subject: [PATCH] infrt-trt run resnet50 (#41442) * add rewrite pattern form paddle op tp trt op * infrt-trt run resnet50. Co-authored-by: weishengying <1343838695@qq.com> --- paddle/infrt/CMakeLists.txt | 3 - paddle/infrt/backends/tensorrt/CMakeLists.txt | 8 +- .../backends/tensorrt/plugin/CMakeLists.txt | 1 + .../backends/tensorrt/plugin/plugin_utils.h | 153 ++++++++++ .../tensorrt/plugin/pool_op_plugin.cu | 288 ++++++++++++++++++ .../backends/tensorrt/plugin/pool_op_plugin.h | 196 ++++++++++++ paddle/infrt/dialect/tensorrt/convert.h | 6 +- paddle/infrt/dialect/tensorrt/trt_exec.cc | 4 +- .../dialect/tensorrt/trt_op_converter_pass.cc | 4 +- paddle/infrt/kernel/tensorrt/trt_helper.h | 12 + paddle/infrt/kernel/tensorrt/trt_kernels.cc | 4 + paddle/infrt/kernel/tensorrt/trt_layers.h | 103 ++++++- 12 files changed, 756 insertions(+), 26 deletions(-) create mode 100644 paddle/infrt/backends/tensorrt/plugin/CMakeLists.txt create mode 100644 paddle/infrt/backends/tensorrt/plugin/plugin_utils.h create mode 100644 paddle/infrt/backends/tensorrt/plugin/pool_op_plugin.cu create mode 100644 paddle/infrt/backends/tensorrt/plugin/pool_op_plugin.h diff --git a/paddle/infrt/CMakeLists.txt b/paddle/infrt/CMakeLists.txt index e777a8e3ab4..0f90ec96db2 100644 --- a/paddle/infrt/CMakeLists.txt +++ b/paddle/infrt/CMakeLists.txt @@ -115,9 +115,6 @@ if (INFRT_WITH_PHI) endif() cc_library(infrt SHARED SRCS ${infrt_src} DEPS glog boost ${mlir_libs} ${phi_libs} paddle_framework_proto infrt_naive) -if (INFRT_WITH_TRT) - target_link_libraries(infrt infrt_trt) -endif() cc_library(infrt_static SRCS ${infrt_src} DEPS glog boost ${mlir_libs} ${phi_libs} paddle_framework_proto) add_dependencies(infrt ${infrt_mlir_incs} mlir-headers) diff --git a/paddle/infrt/backends/tensorrt/CMakeLists.txt b/paddle/infrt/backends/tensorrt/CMakeLists.txt index cc20c9a2e14..672515ea4b7 100644 --- a/paddle/infrt/backends/tensorrt/CMakeLists.txt +++ b/paddle/infrt/backends/tensorrt/CMakeLists.txt @@ -1,3 +1,7 @@ -cc_library(infrt_trt SRCS trt_engine.cc DEPS glog phi_dynload_cuda phi) +add_subdirectory(plugin) -cc_test_tiny(test_infrt_trt SRCS test_trt_engine.cc DEPS infrt_trt phi_dynload_cuda tensorrt_converter) +core_gather_headers() + +gather_srcs(infrt_src SRCS trt_engine.cc) + +cc_test_tiny(test_infrt_trt SRCS test_trt_engine.cc DEPS infrt phi_dynload_cuda tensorrt_converter) diff --git a/paddle/infrt/backends/tensorrt/plugin/CMakeLists.txt b/paddle/infrt/backends/tensorrt/plugin/CMakeLists.txt new file mode 100644 index 00000000000..8848148f2c6 --- /dev/null +++ b/paddle/infrt/backends/tensorrt/plugin/CMakeLists.txt @@ -0,0 +1 @@ +gather_srcs(infrt_src SRCS pool_op_plugin.cu) diff --git a/paddle/infrt/backends/tensorrt/plugin/plugin_utils.h b/paddle/infrt/backends/tensorrt/plugin/plugin_utils.h new file mode 100644 index 00000000000..49e96e6eab0 --- /dev/null +++ b/paddle/infrt/backends/tensorrt/plugin/plugin_utils.h @@ -0,0 +1,153 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include +#include +#include +#include +#include + +#include "paddle/phi/backends/dynload/tensorrt.h" + +namespace infrt { +namespace backends { +namespace tensorrt { +namespace plugin { + +template +inline void SerializeValue(void** buffer, T const& value); + +template +inline void DeserializeValue(void const** buffer, + size_t* buffer_size, + T* value); + +namespace details { + +template +struct Serializer {}; + +template +struct Serializer::value || + std::is_enum::value || + std::is_pod::value>::type> { + static size_t SerializedSize(T const& value) { return sizeof(T); } + + static void Serialize(void** buffer, T const& value) { + std::memcpy(*buffer, &value, sizeof(T)); + reinterpret_cast(*buffer) += sizeof(T); + } + + static void Deserialize(void const** buffer, size_t* buffer_size, T* value) { + assert(*buffer_size >= sizeof(T)); + std::memcpy(value, *buffer, sizeof(T)); + reinterpret_cast(*buffer) += sizeof(T); + *buffer_size -= sizeof(T); + } +}; + +template <> +struct Serializer { + static size_t SerializedSize(const char* value) { return strlen(value) + 1; } + + static void Serialize(void** buffer, const char* value) { + std::strcpy(static_cast(*buffer), value); // NOLINT + reinterpret_cast(*buffer) += strlen(value) + 1; + } + + static void Deserialize(void const** buffer, + size_t* buffer_size, + const char** value) { + *value = static_cast(*buffer); + size_t data_size = strnlen(*value, *buffer_size) + 1; + assert(*buffer_size >= data_size); + reinterpret_cast(*buffer) += data_size; + *buffer_size -= data_size; + } +}; + +template +struct Serializer, + typename std::enable_if::value || + std::is_enum::value || + std::is_pod::value>::type> { + static size_t SerializedSize(std::vector const& value) { + return sizeof(value.size()) + value.size() * sizeof(T); + } + + static void Serialize(void** buffer, std::vector const& value) { + SerializeValue(buffer, value.size()); + size_t nbyte = value.size() * sizeof(T); + std::memcpy(*buffer, value.data(), nbyte); + reinterpret_cast(*buffer) += nbyte; + } + + static void Deserialize(void const** buffer, + size_t* buffer_size, + std::vector* value) { + size_t size; + DeserializeValue(buffer, buffer_size, &size); + value->resize(size); + size_t nbyte = value->size() * sizeof(T); + CHECK_GE(*buffer_size, nbyte); + std::memcpy(value->data(), *buffer, nbyte); + reinterpret_cast(*buffer) += nbyte; + *buffer_size -= nbyte; + } +}; + +} // namespace details + +template +inline size_t SerializedSize(T const& value) { + return details::Serializer::SerializedSize(value); +} + +template +inline void SerializeValue(void** buffer, T const& value) { + return details::Serializer::Serialize(buffer, value); +} + +template +inline void DeserializeValue(void const** buffer, + size_t* buffer_size, + T* value) { + return details::Serializer::Deserialize(buffer, buffer_size, value); +} + +template +class TrtPluginRegistrar { + public: + TrtPluginRegistrar() { + static auto func_ptr = static_cast( + ::phi::dynload::getPluginRegistry()); + func_ptr->registerCreator(instance, ""); + } + + private: + //! Plugin instance. + T instance{}; +}; + +#define REGISTER_TRT_PLUGIN(name) \ + static TrtPluginRegistrar pluginRegistrar##name {} + +} // namespace plugin +} // namespace tensorrt +} // namespace backends +} // namespace infrt diff --git a/paddle/infrt/backends/tensorrt/plugin/pool_op_plugin.cu b/paddle/infrt/backends/tensorrt/plugin/pool_op_plugin.cu new file mode 100644 index 00000000000..5a53777c8e3 --- /dev/null +++ b/paddle/infrt/backends/tensorrt/plugin/pool_op_plugin.cu @@ -0,0 +1,288 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "glog/logging.h" +#include "paddle/infrt/backends/tensorrt/plugin/plugin_utils.h" +#include "paddle/infrt/backends/tensorrt/plugin/pool_op_plugin.h" +#include "paddle/phi/kernels/funcs/pooling.h" + +namespace infrt { +namespace backends { +namespace tensorrt { +namespace plugin { + +PoolPlugin::PoolPlugin(bool ceil_mode, + PoolType pool_type, + bool adaptive, + bool exclusive, + std::vector ksize, + std::vector strides, + std::vector paddings, + std::vector input_shape, + std::vector real_paddings) + : ceil_mode_(ceil_mode), + pool_type_(pool_type), + adaptive_(adaptive), + exclusive_(exclusive), + ksize_(ksize), + strides_(strides), + paddings_(paddings), + real_paddings_(real_paddings), + input_shape_(input_shape) { + output_shape_ = input_shape_; + std::vector output_shape = + CalcOutputSize({input_shape_[1], input_shape_[2]}, + ceil_mode_, + adaptive_, + ksize_, + strides_, + real_paddings_); + output_shape_[1] = output_shape[0]; + output_shape_[2] = output_shape[1]; +} + +PoolPlugin::PoolPlugin(void const* serialData, size_t serialLength) { + // deserializeBase(serialData, serialLength); + DeserializeValue(&serialData, &serialLength, &ceil_mode_); + DeserializeValue(&serialData, &serialLength, &pool_type_); + DeserializeValue(&serialData, &serialLength, &adaptive_); + DeserializeValue(&serialData, &serialLength, &exclusive_); + DeserializeValue(&serialData, &serialLength, &ksize_); + DeserializeValue(&serialData, &serialLength, &strides_); + DeserializeValue(&serialData, &serialLength, &paddings_); + DeserializeValue(&serialData, &serialLength, &real_paddings_); + DeserializeValue(&serialData, &serialLength, &input_shape_); + DeserializeValue(&serialData, &serialLength, &output_shape_); +} + +const char* PoolPlugin::getPluginType() const noexcept { return "pool_plugin"; } + +const char* PoolPlugin::getPluginVersion() const noexcept { return "1"; } + +int PoolPlugin::getNbOutputs() const noexcept { return 1; } + +nvinfer1::Dims PoolPlugin::getOutputDimensions(int outputIndex, + const nvinfer1::Dims* inputs, + int nbInputs) noexcept { + assert(nbInputs == 1); + assert(index == 0); + assert(inputs[0].nbDims == 3); + nvinfer1::Dims const& input_dims = inputs[0]; + + nvinfer1::Dims output_dims = input_dims; + + output_dims.d[1] = output_shape_[1]; + output_dims.d[2] = output_shape_[2]; + return output_dims; +} + +int32_t PoolPlugin::initialize() noexcept { return 0; } + +void PoolPlugin::terminate() noexcept {} + +size_t PoolPlugin::getWorkspaceSize(int32_t maxBatchSize) const noexcept { + return 0; +} + +#if IS_TRT_VERSION_LT(8000) +int PoolPlugin::enqueue(int batch_size, + const void* const* inputs, + void** outputs, +#else +int PoolPlugin::enqueue(int batch_size, + const void* const* inputs, + void* const* outputs, +#endif + void* workspace, + cudaStream_t stream) noexcept { + // TODO(wilber) + int input_size = 0; + float const* idata = reinterpret_cast(inputs[0]); + float* const* odatas = reinterpret_cast(outputs); + + std::vector input_shape = input_shape_; + std::vector output_shape = output_shape_; + input_shape.insert(input_shape.begin(), batch_size); + output_shape.insert(output_shape.begin(), batch_size); + + if (pool_type_ == PoolType::max) { + ::phi::funcs::MaxPool pool_process; + ::phi::funcs::Pool2dDirectCUDAFunctor, float> + pool2d_forward; + pool2d_forward(idata, + input_shape, + output_shape, + ksize_, + strides_, + paddings_, + true, + false, + odatas[0], + stream, + pool_process); + } else if (pool_type_ == PoolType::avg) { + ::phi::funcs::AvgPool pool_process; + ::phi::funcs::Pool2dDirectCUDAFunctor, float> + pool2d_forward; + pool2d_forward(idata, + input_shape, + output_shape, + ksize_, + strides_, + paddings_, + exclusive_, + adaptive_, + odatas[0], + stream, + pool_process); + } + + return cudaGetLastError() != cudaSuccess; +} + +// TODO(wilber): serialize base info? +size_t PoolPlugin::getSerializationSize() const noexcept { + return SerializedSize(ceil_mode_) + SerializedSize(pool_type_) + + SerializedSize(adaptive_) + SerializedSize(exclusive_) + + SerializedSize(ksize_) + SerializedSize(strides_) + + SerializedSize(paddings_) + SerializedSize(real_paddings_) + + SerializedSize(input_shape_) + SerializedSize(output_shape_); +} +// TODO(wilber): serialize base info? +void PoolPlugin::serialize(void* buffer) const noexcept { + // serializeBase(buffer); + SerializeValue(&buffer, ceil_mode_); + SerializeValue(&buffer, pool_type_); + SerializeValue(&buffer, adaptive_); + SerializeValue(&buffer, exclusive_); + SerializeValue(&buffer, ksize_); + SerializeValue(&buffer, strides_); + SerializeValue(&buffer, paddings_); + SerializeValue(&buffer, real_paddings_); + SerializeValue(&buffer, input_shape_); + SerializeValue(&buffer, output_shape_); +} + +void PoolPlugin::destroy() noexcept { delete this; } + +void PoolPlugin::setPluginNamespace(char const* plugin_namespace) noexcept { + namespace_ = plugin_namespace; +} + +char const* PoolPlugin::getPluginNamespace() const noexcept { + return namespace_.c_str(); +} + +nvinfer1::DataType PoolPlugin::getOutputDataType( + int32_t index, + nvinfer1::DataType const* input_types, + int32_t nbInputs) const noexcept { + CHECK_EQ(index, 0); + CHECK_EQ((input_types[0] == nvinfer1::DataType::kFLOAT), true); + return input_types[0]; +} + +bool PoolPlugin::isOutputBroadcastAcrossBatch(int32_t outputIndex, + bool const* inputIsBroadcasted, + int32_t nbInputs) const noexcept { + return false; +} + +bool PoolPlugin::canBroadcastInputAcrossBatch(int32_t inputIndex) const + noexcept { + return false; +} + +nvinfer1::IPluginV2Ext* PoolPlugin::clone() const noexcept { + auto* plugin = new PoolPlugin(ceil_mode_, + pool_type_, + adaptive_, + exclusive_, + ksize_, + strides_, + paddings_, + input_shape_, + real_paddings_); + plugin->setPluginNamespace(namespace_.c_str()); + return plugin; +} + +void PoolPlugin::configurePlugin(nvinfer1::PluginTensorDesc const* in, + int32_t nb_input, + nvinfer1::PluginTensorDesc const* out, + int32_t nb_output) noexcept { + CHECK_EQ(nb_input, 1); + CHECK_EQ(nb_output, 1); + + input_dims_ = in[0].dims; + data_format_ = in[0].format; + data_type_ = in[0].type; +} + +bool PoolPlugin::supportsFormatCombination( + int32_t pos, + nvinfer1::PluginTensorDesc const* in_out, + int32_t nb_inputs, + int32_t nb_outputs) const noexcept { + CHECK_LT(pos, nb_inputs + nb_outputs); + CHECK_NOTNULL(in_out); + + return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) && + in_out[pos].format == nvinfer1::PluginFormat::kLINEAR); +} + +nvinfer1::IPluginV2* PoolPluginCreator::createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept { + // auto* plugin = new UffPoolPluginV2(*fc); + field_collection_ = *fc; + plugin_name_ = name; + const nvinfer1::PluginField* fields = fc->fields; + + bool ceil_mode; + PoolPlugin::PoolType pool_type; + bool adaptive; + bool exclusive; + std::vector ksize; + std::vector strides; + std::vector paddings; + std::vector real_paddings; + std::vector input_shape; + std::vector output_shape; + + // TODO(wilber): add implement. + CHECK(false) << "not implement"; + // for (int i = 0; i < fc->nbFields; ++i) { + // const char* attr_name = fields[i].name; + // if (!strcmp(attr_name, "ceil_mode")) { + // CHECK_EQ(fields[i].type == nvinfer1::PluginFieldType::kINT8, true); + // ceil_mode = *static_cast(fields[i].data); + // // mParam.numOutputBoxesPerClass = + // // *(static_cast(fields[i].data)); + // } + // } + + return nullptr; +} + +nvinfer1::IPluginV2* PoolPluginCreator::deserializePlugin( + const char* name, const void* serialData, size_t serialLength) noexcept { + auto* plugin = new PoolPlugin(serialData, serialLength); + plugin_name_ = name; + return plugin; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace backends +} // namespace infrt diff --git a/paddle/infrt/backends/tensorrt/plugin/pool_op_plugin.h b/paddle/infrt/backends/tensorrt/plugin/pool_op_plugin.h new file mode 100644 index 00000000000..0da1d158453 --- /dev/null +++ b/paddle/infrt/backends/tensorrt/plugin/pool_op_plugin.h @@ -0,0 +1,196 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include +#include +#include + +#include "paddle/infrt/backends/tensorrt/plugin/plugin_utils.h" +#include "paddle/infrt/backends/tensorrt/trt_utils.h" + +namespace infrt { +namespace backends { +namespace tensorrt { +namespace plugin { + +static std::vector CalcOutputSize(const std::vector& input_shape, + const bool& ceil_mode, + const bool& adaptive, + const std::vector& ksize, + const std::vector& strides, + const std::vector& real_paddings) { + std::vector output_shape = input_shape; + if (adaptive) { + output_shape[0] = ksize[0]; + output_shape[1] = ksize[1]; + } else { + int output_h = 0, output_w = 0; + if (ceil_mode) { + output_h = (input_shape[0] - ksize[0] + real_paddings[0] + + real_paddings[1] + strides[0] - 1) / + strides[0] + + 1; + output_w = (input_shape[1] - ksize[1] + real_paddings[2] + + real_paddings[3] + strides[1] - 1) / + strides[1] + + 1; + } + // TRT will use native layer when ceil_model=false + /* + else{ + output_h = (input_shape[0] - ksize[0] + real_paddings[0] + + real_paddings[1]) / strides[0] + 1; + output_w = (input_shape[1] - ksize[1] + real_paddings[2] + + real_paddings[3]) / strides[1] + 1; + } + */ + output_shape[0] = output_h; + output_shape[1] = output_w; + } + return output_shape; +} + +class PoolPlugin : public nvinfer1::IPluginV2IOExt { + public: + enum class PoolType { + max = 0, + avg, + }; + + PoolPlugin() {} + PoolPlugin(bool ceil_mode, + PoolType pool_type, + bool adaptive, + bool exclusive, + std::vector ksize, + std::vector strides, + std::vector paddings, + std::vector input_shape, + std::vector real_paddings); + + PoolPlugin(void const* serialData, size_t serialLength); + + // IPluginV2 methods + const char* getPluginType() const noexcept override; + const char* getPluginVersion() const noexcept override; + int getNbOutputs() const noexcept override; + nvinfer1::Dims getOutputDimensions(int outputIndex, + const nvinfer1::Dims* inputs, + int nbInputs) noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept override; +#if IS_TRT_VERSION_LT(8000) + int enqueue(int batchSize, + const void* const* inputs, + void** outputs, +#else + int enqueue(int batchSize, + const void* const* inputs, + void* const* outputs, +#endif + void* workspace, + cudaStream_t stream) noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + // IPluginV2Ext methods + nvinfer1::DataType getOutputDataType(int32_t index, + nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const + noexcept override; + bool isOutputBroadcastAcrossBatch(int32_t outputIndex, + bool const* inputIsBroadcasted, + int32_t nbInputs) const noexcept override; + bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept override; + // void attachToContext(cudnnContext*, + // cublasContext*, + // IGpuAllocator*) noexcept override; + // void detachFromContext() noexcept override; + IPluginV2Ext* clone() const noexcept override; + + // IPluginV2IOExt methods + void configurePlugin(nvinfer1::PluginTensorDesc const* in, + int32_t nb_input, + nvinfer1::PluginTensorDesc const* out, + int32_t nb_output) noexcept override; + bool supportsFormatCombination(int32_t pos, + nvinfer1::PluginTensorDesc const* inOut, + int32_t nb_inputs, + int32_t nb_outputs) const noexcept override; + + private: + bool ceil_mode_; + PoolType pool_type_; + bool adaptive_; + bool exclusive_; + std::vector ksize_; + std::vector strides_; + std::vector paddings_; + std::vector real_paddings_; + std::vector input_shape_; + std::vector output_shape_; + + private: + nvinfer1::Dims input_dims_; + nvinfer1::DataType data_type_; + nvinfer1::PluginFormat data_format_; + std::string namespace_; +}; + +class PoolPluginCreator : public nvinfer1::IPluginCreator { + public: + const char* getPluginName() const noexcept override { return "pool_plugin"; } + + const char* getPluginVersion() const noexcept override { return "1"; } + + const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override { + return &field_collection_; + } + + nvinfer1::IPluginV2* createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) noexcept override; + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) noexcept override; + + void setPluginNamespace(const char* plugin_namespace) noexcept override { + plugin_namespace_ = plugin_namespace; + } + + const char* getPluginNamespace() const noexcept override { + return plugin_namespace_.c_str(); + } + + private: + std::string plugin_namespace_; + std::string plugin_name_; + nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; +}; +REGISTER_TRT_PLUGIN(PoolPluginCreator); + +} // namespace plugin +} // namespace tensorrt +} // namespace backends +} // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/convert.h b/paddle/infrt/dialect/tensorrt/convert.h index 5b9e4a90745..be363e77848 100644 --- a/paddle/infrt/dialect/tensorrt/convert.h +++ b/paddle/infrt/dialect/tensorrt/convert.h @@ -320,9 +320,9 @@ inline ::llvm::SmallVector<::mlir::Value, 4> CreatePaddleTrtPoolingOp( } // if global_pooling == true or adaptive == true, padding will be ignored - if (global_pooling.getValue() || adaptive.getValue()) { - paddings_attr = builder.getI32ArrayAttr({0, 0}); - } + // if (global_pooling.getValue() || adaptive.getValue()) { + // paddings_attr = builder.getI32ArrayAttr({0, 0}); + // } // if global_pooling == true, then we should update kernel size to input dims. if (global_pooling.getValue() == true) { diff --git a/paddle/infrt/dialect/tensorrt/trt_exec.cc b/paddle/infrt/dialect/tensorrt/trt_exec.cc index b37186ada6d..837ca209374 100644 --- a/paddle/infrt/dialect/tensorrt/trt_exec.cc +++ b/paddle/infrt/dialect/tensorrt/trt_exec.cc @@ -72,7 +72,7 @@ int main(int argc, char** argv) { #endif context->loadAllAvailableDialects(); - module->dump(); + // module->dump(); mlir::PassManager pm(context); mlir::OpPassManager& trt_pass_manager = pm.nest(); @@ -87,7 +87,7 @@ int main(int argc, char** argv) { std::cout << "\npass failed!\n" << std::endl; return 4; } - module->dump(); + // module->dump(); ::infrt::host_context::TestMlir(module.get(), ®istry); return 0; } diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc index 5273bcaa6aa..e40bbd67c0b 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc @@ -186,7 +186,7 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern { create_scale_tensor_op->getLoc(), create_scale_tensor_op.output().getType(), create_scale_tensor_op.context(), - create_bias_tensor_op.dims(), + create_scale_tensor_op.dims(), ::infrt::LayoutAttr::get(rewriter.getContext(), ::infrt::LayoutType::NCHW), create_scale_tensor_op.lod(), @@ -206,7 +206,6 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern { rewriter.getF32ArrayAttr(combile_bias_data)); rewriter.replaceOp(create_bias_tensor_op, new_bias_op->getResults()); - rewriter.setInsertionPoint(op); trt::ScaleNdOp scaleNd_op; // resultTypes ::mlir::SmallVector<::mlir::Type, 4> resultTypes; @@ -215,6 +214,7 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern { } // attributes + rewriter.setInsertionPoint(op); ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes; auto result = rewriter .create( diff --git a/paddle/infrt/kernel/tensorrt/trt_helper.h b/paddle/infrt/kernel/tensorrt/trt_helper.h index 13529430d68..4f1f1dde38c 100644 --- a/paddle/infrt/kernel/tensorrt/trt_helper.h +++ b/paddle/infrt/kernel/tensorrt/trt_helper.h @@ -52,6 +52,18 @@ static nvinfer1::Dims ArrayAttrToNvDims(const mlir::ArrayAttr& int_array_attr) { return dims; } +template +static std::vector ArrayAttrToVec(const mlir::ArrayAttr& int_array_attr) { + std::vector ret; + ret.resize(int_array_attr.size()); + CHECK(!int_array_attr.empty()); + CHECK(int_array_attr[0].getType().isIntOrIndex()); + for (size_t i = 0; i < int_array_attr.size(); ++i) { + ret[i] = int_array_attr[i].cast().getInt(); + } + return ret; +} + static nvinfer1::Weights TensorToWeights(::phi::DenseTensor* tensor) { CHECK_NOTNULL(tensor); nvinfer1::Weights ret; diff --git a/paddle/infrt/kernel/tensorrt/trt_kernels.cc b/paddle/infrt/kernel/tensorrt/trt_kernels.cc index 9b7fb200093..c182dda2705 100644 --- a/paddle/infrt/kernel/tensorrt/trt_kernels.cc +++ b/paddle/infrt/kernel/tensorrt/trt_kernels.cc @@ -147,6 +147,10 @@ namespace tensorrt { } else if (trt::ScaleNdOp op = llvm::dyn_cast(operation)) { ScaleNdFunc( op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); + } else if (trt::ElementWiseOp op = + llvm::dyn_cast(operation)) { + EltwiseFunc( + op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); } else { CHECK(false) << "not supported operation."; } diff --git a/paddle/infrt/kernel/tensorrt/trt_layers.h b/paddle/infrt/kernel/tensorrt/trt_layers.h index 8c7dd4d8132..9d8eba0bb31 100644 --- a/paddle/infrt/kernel/tensorrt/trt_layers.h +++ b/paddle/infrt/kernel/tensorrt/trt_layers.h @@ -22,6 +22,7 @@ #include +#include "paddle/infrt/backends/tensorrt/plugin/pool_op_plugin.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include "paddle/infrt/kernel/tensorrt/trt_helper.h" #include "paddle/phi/core/dense_tensor.h" @@ -78,6 +79,9 @@ inline void ConvFunc(trt::ConvolutionOp& op, // NOLINT dims, kernel_weights, bias_weights); + + layer->setPaddingNd(ArrayAttrToNvDims(op.paddings())); + layer->setStrideNd(ArrayAttrToNvDims(op.strides())); CHECK_NOTNULL(layer); mlir::Value out_repr = op.output_tensor(); nvinfer1::ITensor* out_tensor = layer->getOutput(0); @@ -90,8 +94,8 @@ inline void PoolFunc(trt::PoolingOp& op, // NOLINT ValueToTensorMap& value_to_tensor_map) { // NOLINT mlir::Value input_tensor_repr = op.input_tensor(); nvinfer1::ITensor* input_itensor = value_to_trt_tensor_map[input_tensor_repr]; - // nvinfer1::Dims input_shape = input_itensor->getDimensions(); - // int input_dims = input_shape.nbDims; + nvinfer1::Dims input_shape = input_itensor->getDimensions(); + int input_dims = input_shape.nbDims; auto padding_mode = op.padding_mode(); auto pool_type = op.pool_type(); @@ -109,7 +113,35 @@ inline void PoolFunc(trt::PoolingOp& op, // NOLINT if (adaptive) { // TODO(Inference) - CHECK(false) << "Not supported adaptive pool"; + // CHECK(false) << "Not supported adaptive pool"; + + std::vector input_shape_v; + for (int i = 0; i < input_dims; i++) { + input_shape_v.push_back(input_shape.d[i]); + } + auto paddings_val = ArrayAttrToVec(paddings); + std::vector real_paddings = paddings_val; + for (int i = 0; i < 2; ++i) { + int copy_pad = *(paddings_val.begin() + i); + real_paddings.insert(real_paddings.begin() + 2 * i + 1, copy_pad); + } + + auto* plugin = new backends::tensorrt::plugin::PoolPlugin( + false, + backends::tensorrt::plugin::PoolPlugin::PoolType::avg, + adaptive, + exclusive, + ArrayAttrToVec(ksize), + ArrayAttrToVec(strides), + paddings_val, + input_shape_v, + real_paddings); + auto* layer = network->addPluginV2(&input_itensor, 1, *plugin); + + mlir::Value out_repr = op.output_tensor(); + nvinfer1::ITensor* out_tensor = layer->getOutput(0); + value_to_trt_tensor_map[out_repr] = out_tensor; + return; } nvinfer1::Dims window_size = ArrayAttrToNvDims(ksize); @@ -136,19 +168,41 @@ inline void FcFunc(trt::FullyConnectedOp& op, // NOLINT mlir::Value input_tensor_repr = op.input_tensor(); CHECK(value_to_trt_tensor_map.count(input_tensor_repr)); + nvinfer1::ITensor* input_itensor = value_to_trt_tensor_map[input_tensor_repr]; + nvinfer1::Dims input_shape = input_itensor->getDimensions(); + int input_dims = input_shape.nbDims; + CHECK_EQ(input_dims, 1) << "Now we only support 2-d input."; + // TODO(wilber): We should place the logic to ir. Now only support 2-d input + // and we reshape to 4-d. + nvinfer1::Dims reshape_before_fc_dim; + reshape_before_fc_dim.nbDims = input_dims + 2; + // padding shape "* x q x 1 x 1" + for (int i = 0; i < reshape_before_fc_dim.nbDims; i++) { + reshape_before_fc_dim.d[i] = 1; + } + reshape_before_fc_dim.d[0] = input_shape.d[0]; + auto* reshape_before_fc_layer = network->addShuffle(*input_itensor); + reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + + auto* reshape_itensor = reshape_before_fc_layer->getOutput(0); + auto kernel_weights = TensorToWeights(value_to_tensor_map[op.kernel_weights()]); auto bias_weights = TensorToWeights(value_to_tensor_map[op.bias_weights()]); int out_channel_num = op.out_channel_num(); - auto* layer = - network->addFullyConnected(*value_to_trt_tensor_map[input_tensor_repr], - out_channel_num, - kernel_weights, - bias_weights); + auto* layer = network->addFullyConnected( + *reshape_itensor, out_channel_num, kernel_weights, bias_weights); + + // TODO(wilber): fix. + nvinfer1::Dims reshape_after_fc_dim; + reshape_after_fc_dim.nbDims = 1; + reshape_after_fc_dim.d[0] = layer->getOutput(0)->getDimensions().d[0]; + auto* reshape_after_fc_layer = network->addShuffle(*layer->getOutput(0)); + reshape_after_fc_layer->setReshapeDimensions(reshape_after_fc_dim); mlir::Value out_repr = op.output_tensor(); - nvinfer1::ITensor* out_tensor = layer->getOutput(0); + nvinfer1::ITensor* out_tensor = reshape_after_fc_layer->getOutput(0); value_to_trt_tensor_map[out_repr] = out_tensor; } @@ -159,14 +213,12 @@ inline void ShuffleFunc(trt::ShuffleOp& op, // NOLINT mlir::Value input_tensor_repr = op.input_tensor(); nvinfer1::ITensor* input = value_to_trt_tensor_map[input_tensor_repr]; int dims = input->getDimensions().nbDims; - - int start_axis = op.start_axisAttr().getInt(); - int stop_axis = op.start_axisAttr().getInt(); + int start_axis = op.start_axis(); + int stop_axis = op.stop_axis(); nvinfer1::IShuffleLayer* layer = nullptr; if (start_axis < 0) start_axis += dims + 1; if (stop_axis < 0) stop_axis += dims + 1; - int dim_prod = 1; nvinfer1::Dims flatten_dim; flatten_dim.nbDims = dims - (stop_axis - start_axis); @@ -185,7 +237,6 @@ inline void ShuffleFunc(trt::ShuffleOp& op, // NOLINT layer = network->addShuffle(*value_to_trt_tensor_map[input_tensor_repr]); CHECK_NOTNULL(layer); layer->setReshapeDimensions(flatten_dim); - for (size_t i = 0; i < op->getNumResults(); ++i) { nvinfer1::ITensor* out_tensor = layer->getOutput(i); mlir::Value out_value = op->getResult(i); @@ -222,6 +273,30 @@ inline void ScaleNdFunc(trt::ScaleNdOp& op, // NOLINT value_to_trt_tensor_map[out_value] = out_tensor; } } + +inline void EltwiseFunc(trt::ElementWiseOp& op, // NOLINT + nvinfer1::INetworkDefinition* network, + ValueToITensorMap& value_to_trt_tensor_map, // NOLINT + ValueToTensorMap& value_to_tensor_map) { // NOLINT + mlir::Value input1_tensor_repr = op.input1(); + mlir::Value input2_tensor_repr = op.input2(); + nvinfer1::ITensor* input1 = value_to_trt_tensor_map[input1_tensor_repr]; + nvinfer1::ITensor* input2 = value_to_trt_tensor_map[input2_tensor_repr]; + + auto eltwise_operation = op.elementwise_operation(); + + auto* layer = network->addElementWise( + *input1, + *input2, + static_cast(eltwise_operation)); + CHECK_NOTNULL(layer); + for (size_t i = 0; i < op->getNumResults(); ++i) { + nvinfer1::ITensor* out_tensor = layer->getOutput(i); + mlir::Value out_value = op->getResult(i); + value_to_trt_tensor_map[out_value] = out_tensor; + } +} + } // namespace tensorrt } // namespace kernel } // namespace infrt -- GitLab