From 07933116a1cab5deaa34d94b0ec9c67e9e9f4c04 Mon Sep 17 00:00:00 2001 From: weishengying <63448337+weishengying@users.noreply.github.com> Date: Thu, 15 Sep 2022 19:46:09 +0800 Subject: [PATCH] General Plugin Mechanism (#45355) (#46070) --- .../fluid/inference/api/analysis_predictor.cc | 2 + .../fluid/inference/tensorrt/CMakeLists.txt | 14 +- .../inference/tensorrt/convert/CMakeLists.txt | 15 +- .../generic_and_custom_plugin_creater.cc | 248 ++++++++ .../inference/tensorrt/convert/op_converter.h | 219 ++++--- .../tensorrt/convert/test_custom_op_plugin.h | 356 +++++++++++ .../convert/test_custom_plugin_creater.cc | 209 +++++++ .../tensorrt/convert/test_op_converter.cc | 1 + .../tensorrt/dynamic_shape_infermeta.cc | 60 ++ .../dynamic_shape_infermeta_factory.h | 99 +++ .../dynamic_shape_infermeta_registry.h | 26 + paddle/fluid/inference/tensorrt/op_teller.cc | 581 ++++++++++-------- paddle/fluid/inference/tensorrt/op_teller.h | 29 +- .../inference/tensorrt/plugin/CMakeLists.txt | 11 +- .../tensorrt/plugin/generic_plugin.cu | 463 ++++++++++++++ .../tensorrt/plugin/generic_plugin.h | 162 +++++ .../tensorrt/plugin/mish_op_plugin.h | 9 +- .../tensorrt/plugin_arg_mapping_context.cc | 122 ++++ .../tensorrt/plugin_arg_mapping_context.h | 62 ++ .../tensorrt/test_arg_mapping_context.cc | 132 ++++ .../tensorrt/tensorrt_engine_op_test.cc | 2 + .../inference/test_trt_convert_gather_nd.py | 134 +++- .../ir/inference/test_trt_convert_yolo_box.py | 18 +- 23 files changed, 2579 insertions(+), 395 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/generic_and_custom_plugin_creater.cc create mode 100644 paddle/fluid/inference/tensorrt/convert/test_custom_op_plugin.h create mode 100644 paddle/fluid/inference/tensorrt/convert/test_custom_plugin_creater.cc create mode 100644 paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc create mode 100644 paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_factory.h create mode 100644 paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_registry.h create mode 100644 paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/generic_plugin.h create mode 100644 paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc create mode 100644 paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h create mode 100644 paddle/fluid/inference/tensorrt/test_arg_mapping_context.cc diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index fbc2830aff6..445145dde39 100755 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2185,6 +2185,8 @@ USE_TRT_CONVERTER(shape) USE_TRT_CONVERTER(fill_constant) USE_TRT_CONVERTER(fused_token_prune) USE_TRT_CONVERTER(layernorm_shift_partition) +USE_TRT_CONVERTER(generic_plugin_creater) +USE_TRT_CONVERTER(custom_plugin_creater) #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) USE_TRT_CONVERTER(sparse_fc) USE_TRT_CONVERTER(sparse_multihead_matmul) diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt index 7239b506d33..d4a4c8c06af 100644 --- a/paddle/fluid/inference/tensorrt/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt @@ -12,10 +12,18 @@ else() SRCS engine.cc trt_int8_calibrator.cc DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context) endif() +nv_library( + tensorrt_dynamic_shape_infermeta_factory + SRCS dynamic_shape_infermeta.cc + DEPS framework_proto) +nv_library( + tensorrt_plugin_arg_mapping_context + SRCS plugin_arg_mapping_context.cc + DEPS framework_proto) nv_library( tensorrt_op_teller SRCS op_teller.cc - DEPS framework_proto device_context) + DEPS framework_proto device_context tensorrt_dynamic_shape_infermeta_factory) nv_test( test_tensorrt SRCS test_tensorrt.cc @@ -24,6 +32,10 @@ nv_test( test_tensorrt_engine SRCS test_engine.cc test_dynamic_engine.cc DEPS dynload_cuda tensorrt_engine tensorrt_plugin) +nv_test( + test_arg_mapping_context + SRCS test_arg_mapping_context.cc + DEPS framework_proto tensorrt_plugin_arg_mapping_context) if(WITH_ONNXRUNTIME AND WIN32) # Copy onnxruntime for some c++ test in Windows, since the test will diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index ce95363b72d..60a5d0f2825 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -76,7 +76,8 @@ list( shape_op.cc fill_constant_op.cc fused_token_prune_op.cc - layernorm_shift_partition_op.cc) + layernorm_shift_partition_op.cc + generic_and_custom_plugin_creater.cc) if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc) @@ -85,7 +86,12 @@ endif() nv_library( tensorrt_converter SRCS ${CONVERT_FILES} - DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto + DEPS tensorrt_engine + tensorrt_plugin + operator + scope + framework_proto + tensorrt_op_teller op_registry) nv_test( @@ -94,6 +100,11 @@ nv_test( DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_converter) +nv_test( + test_custom_plugin_creater + SRCS test_custom_plugin_creater.cc + DEPS paddle_framework tensorrt_converter op_meta_info custom_operator) + if(WITH_ONNXRUNTIME AND WIN32) # Copy onnxruntime for some c++ test in Windows, since the test will # be build only in CI, so suppose the generator in Windows is Ninja. diff --git a/paddle/fluid/inference/tensorrt/convert/generic_and_custom_plugin_creater.cc b/paddle/fluid/inference/tensorrt/convert/generic_and_custom_plugin_creater.cc new file mode 100644 index 00000000000..e1ce9ceb020 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/generic_and_custom_plugin_creater.cc @@ -0,0 +1,248 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_meta_info_helper.h" +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/helper.h" +#include "paddle/fluid/inference/tensorrt/plugin/generic_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +/* + * Stack converter from fluid to tensorRT. + */ +class CustomPluginCreater : public OpConverter { + public: + void operator()(const framework::proto::OpDesc &op, + const framework::Scope &scope, + bool test_mode) override { + framework::OpDesc op_desc(op, nullptr); + VLOG(3) << "convert " << op_desc.Type() << " op to custom pluign layer"; + + std::string plugin_name; + + if (engine_->with_dynamic_shape()) { + plugin_name = op_desc.Type() + "_paddle_trt_dynamic_plugin"; + } else { + plugin_name = op_desc.Type() + "_paddle_trt_plugin"; + } + + nvinfer1::ILayer *layer = nullptr; + std::vector inputs; + + auto &op_meta_info_map = OpMetaInfoMap::Instance(); + const auto &meta_info_map = op_meta_info_map.GetMap(); + auto &op_info = meta_info_map.at(op_desc.Type()).front(); + + // set inputs + auto &op_input_names = framework::OpMetaInfoHelper::GetInputs(op_info); + for (auto ¶m_name : op_input_names) { + for (auto &arg_name : op_desc.Input(param_name)) { + framework::Variable *X_v = nullptr; + X_v = scope.FindVar(arg_name); + // If this weight is not shared between ops, it need to be convtered to + // itensor + if (X_v && !engine_->GetITensorMap()->count(arg_name)) { + ConvertWeight2ITensor(scope, arg_name); + } + inputs.push_back(engine_->GetITensor(arg_name)); + } + } + auto creator = + GetPluginRegistry()->getPluginCreator(plugin_name.c_str(), "1"); + CHECK(creator); + + // set attrs + std::vector plugindatas; + auto &op_attrs_names = framework::OpMetaInfoHelper::GetAttrs(op_info); + auto &attrs = op_desc.GetAttrMap(); + + std::list int_attrs; + std::list float_attrs; + std::list bool_attrs; + std::list string_attrs; + std::list> ints_attrs; + std::list> floats_attrs; + + for (auto &attr_name : op_attrs_names) { + nvinfer1::PluginField plugindata; + plugindata.name = attr_name.c_str(); + if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::INT) { + int_attrs.push_back(PADDLE_GET_CONST(int, attrs.at(attr_name))); + plugindata.data = &int_attrs.back(); + plugindata.type = nvinfer1::PluginFieldType::kINT32; + plugindata.length = 1; + } else if (op_desc.GetAttrType(attr_name) == + framework::proto::AttrType::FLOAT) { + float_attrs.push_back(PADDLE_GET_CONST(float, attrs.at(attr_name))); + plugindata.data = &float_attrs.back(); + plugindata.type = nvinfer1::PluginFieldType::kFLOAT32; + plugindata.length = 1; + } else if (op_desc.GetAttrType(attr_name) == + framework::proto::AttrType::BOOLEAN) { + int_attrs.push_back(PADDLE_GET_CONST(bool, attrs.at(attr_name))); + plugindata.data = &int_attrs.back(); + plugindata.type = nvinfer1::PluginFieldType::kINT32; + plugindata.length = 1; + } else if (op_desc.GetAttrType(attr_name) == + framework::proto::AttrType::STRING) { + string_attrs.push_back( + PADDLE_GET_CONST(std::string, attrs.at(attr_name))); + plugindata.data = string_attrs.back().data(); + plugindata.type = nvinfer1::PluginFieldType::kCHAR; + plugindata.length = + string_attrs.back().size() + 1; // string ends with ‘\0’ + } else if (op_desc.GetAttrType(attr_name) == + framework::proto::AttrType::INTS) { + ints_attrs.push_back( + PADDLE_GET_CONST(std::vector, attrs.at(attr_name))); + plugindata.data = ints_attrs.back().data(); + plugindata.type = nvinfer1::PluginFieldType::kINT32; + plugindata.length = ints_attrs.back().size(); + } else if (op_desc.GetAttrType(attr_name) == + framework::proto::AttrType::FLOATS) { + floats_attrs.push_back( + PADDLE_GET_CONST(std::vector, attrs.at(attr_name))); + plugindata.data = floats_attrs.back().data(); + plugindata.type = nvinfer1::PluginFieldType::kFLOAT32; + plugindata.length = floats_attrs.back().size(); + } else if (op_desc.GetAttrType(attr_name) == + framework::proto::AttrType::BOOLEANS) { + auto bools_attr = + PADDLE_GET_CONST(std::vector, attrs.at(attr_name)); + std::vector convert_to_ints_attr; + for (bool i : bools_attr) convert_to_ints_attr.push_back(i); + ints_attrs.push_back(convert_to_ints_attr); + plugindata.data = ints_attrs.back().data(); + plugindata.type = nvinfer1::PluginFieldType::kINT32; + plugindata.length = ints_attrs.back().size(); + } else { + CHECK(false) << "UNKNOWN PluginFieldType."; + } + plugindatas.push_back(plugindata); + } + + nvinfer1::PluginFieldCollection plugin_fc{(int32_t)plugindatas.size(), + plugindatas.data()}; + + auto *plugin = creator->createPlugin(op_desc.Type().c_str(), &plugin_fc); + CHECK(plugin); + + if (engine_->with_dynamic_shape()) { + layer = + engine_->AddDynamicPlugin(inputs.data(), + inputs.size(), + (plugin::DynamicPluginTensorRT *)plugin); + } else { + layer = engine_->AddPlugin( + inputs.data(), inputs.size(), (plugin::PluginTensorRT *)plugin); + } + + CHECK(layer); + + // set outputs + auto &op_output_names = framework::OpMetaInfoHelper::GetOutputs(op_info); + std::vector output_names; + for (auto ¶m_name : op_output_names) { + for (auto &arg_name : op_desc.Output(param_name)) + output_names.push_back(arg_name); + } + + RreplenishLayerAndOutput(layer, op_desc.Type(), output_names, test_mode); + } +}; + +class GenericPluginCreater : public OpConverter { + public: + void operator()(const framework::proto::OpDesc &op, + const framework::Scope &scope, + bool test_mode) override { + framework::OpDesc op_desc(op, nullptr); + CHECK(block_); + const framework::BlockDesc block_desc( + nullptr, const_cast(block_)); + + nvinfer1::ILayer *layer = nullptr; + std::vector inputs; + + phi::KernelSignature phi_kernel_signature; + if (phi::OpUtilsMap::Instance().HasArgumentMappingFn(op_desc.Type())) { + const phi::ArgumentMappingFn *argument_mapping_func = + phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_desc.Type()); + PluginArgumentMappingContext argument_mapping_context(&op_desc); + phi_kernel_signature = (*argument_mapping_func)(argument_mapping_context); + } else { + phi_kernel_signature = + phi::DefaultKernelSignatureMap::Instance().Get(op_desc.Type()); + } + + plugin::GenericPlugin::InputOutPutVarInfo in_out_info; + + for (auto ¶m_name : phi_kernel_signature.input_names) { + for (auto &arg_name : op_desc.Input(param_name)) { + framework::Variable *X_v = nullptr; + X_v = scope.FindVar(arg_name); + // If this weight is not shared between ops, it need to be convtered to + // itensor + if (X_v && !engine_->GetITensorMap()->count(arg_name)) { + ConvertWeight2ITensor(scope, arg_name); + } + + inputs.push_back(engine_->GetITensor(arg_name)); + auto *var = block_desc.FindVar(arg_name); + PADDLE_ENFORCE_NOT_NULL( + var, + platform::errors::NotFound( + "There is no variable called %s in block.", arg_name.c_str())); + PADDLE_ENFORCE_EQ( + var->GetType(), + FluidDT::VarType_Type_LOD_TENSOR, + platform::errors::InvalidArgument("TensorRT engine only takes " + "LoDTensor as input")); + in_out_info.inputs_data_type.push_back(var->GetDataType()); + } + } + + std::vector output_names; + for (auto ¶m_name : phi_kernel_signature.output_names) { + for (auto &arg_name : op_desc.Output(param_name)) { + output_names.push_back(arg_name); + auto *var = block_desc.FindVar(arg_name); + PADDLE_ENFORCE_NOT_NULL( + var, + platform::errors::NotFound( + "There is no variable called %s in block.", arg_name.c_str())); + PADDLE_ENFORCE_EQ( + var->GetType(), + FluidDT::VarType_Type_LOD_TENSOR, + platform::errors::InvalidArgument("TensorRT engine only takes " + "LoDTensor as input")); + in_out_info.outputs_data_type.push_back(var->GetDataType()); + } + } + plugin::GenericPlugin *plugin = new plugin::GenericPlugin(op, in_out_info); + layer = engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin); + + RreplenishLayerAndOutput(layer, op_desc.Type(), output_names, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(custom_plugin_creater, CustomPluginCreater); +REGISTER_TRT_OP_CONVERTER(generic_plugin_creater, GenericPluginCreater); diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index cdd6345c484..095457dbfbb 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/helper.h" +#include "paddle/fluid/inference/tensorrt/op_teller.h" #include "paddle/fluid/inference/utils/singleton.h" namespace paddle { @@ -49,111 +50,135 @@ class OpConverter { const std::unordered_set& parameters, const framework::Scope& scope, TensorRTEngine* engine, - bool test_mode = false) { + bool test_mode = false, + const framework::proto::BlockDesc* block = nullptr) { framework::OpDesc op_desc(op, nullptr); OpConverter* it{nullptr}; - if (op_desc.Type() == "mul") { - PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), - 1UL, - platform::errors::InvalidArgument( - "The input op mul's Input(\"Y\")." - "size() should equal to 1, but reveceid " - "Input(\"Y\").size() = %u.", - op_desc.Input("Y").size())); - std::string Y = op_desc.Input("Y")[0]; - if (parameters.count(Y)) { - it = Registry::Global().Lookup("fc"); - } - } - if (op_desc.Type().find("elementwise") != std::string::npos) { - static std::unordered_set add_tensor_op_set{ - "add", "mul", "sub", "div", "max", "min", "pow"}; - static std::unordered_set add_weight_op_set{ - "add", "mul", "sub", "div", "pow"}; - PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), - 1UL, - platform::errors::InvalidArgument( - "The input op's Input(\"Y\")." - "size() should equal to 1, but reveceid " - "Input(\"Y\").size() = %u.", - op_desc.Input("Y").size())); - int op_type_len = op_desc.Type().size(); - std::string op_type = op_desc.Type().substr(op_type_len - 3, op_type_len); - std::string Y = op_desc.Input("Y")[0]; - if (parameters.count(Y)) { - PADDLE_ENFORCE_GT( - add_weight_op_set.count(op_type), - 0, - platform::errors::Unimplemented("Unsupported elementwise type %s", - op_type.c_str())); - it = Registry::Global().Lookup("elementwise_" + op_type + - "_weight"); - PADDLE_ENFORCE_NOT_NULL( - it, - platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); - } else { - PADDLE_ENFORCE_GT( - add_tensor_op_set.count(op_type), - 0, - platform::errors::Unimplemented("Unsupported elementwise type %s", - op_type.c_str())); - it = Registry::Global().Lookup("elementwise_" + op_type + - "_tensor"); - } - PADDLE_ENFORCE_NOT_NULL( - it, - platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); - } + auto op_converter_type_map = OpTeller::Global().GetOpConverterTypeMap(); + switch (op_converter_type_map.at(op_desc.Type())) { + case OpConverterType::Default: + if (op_desc.Type() == "mul") { + PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), + 1UL, + platform::errors::InvalidArgument( + "The input op mul's Input(\"Y\")." + "size() should equal to 1, but reveceid " + "Input(\"Y\").size() = %u.", + op_desc.Input("Y").size())); + std::string Y = op_desc.Input("Y")[0]; + if (parameters.count(Y)) { + it = Registry::Global().Lookup("fc"); + } + } + if (op_desc.Type().find("elementwise") != std::string::npos) { + static std::unordered_set add_tensor_op_set{ + "add", "mul", "sub", "div", "max", "min", "pow"}; + static std::unordered_set add_weight_op_set{ + "add", "mul", "sub", "div", "pow"}; + PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), + 1UL, + platform::errors::InvalidArgument( + "The input op's Input(\"Y\")." + "size() should equal to 1, but reveceid " + "Input(\"Y\").size() = %u.", + op_desc.Input("Y").size())); + int op_type_len = op_desc.Type().size(); + std::string op_type = + op_desc.Type().substr(op_type_len - 3, op_type_len); + std::string Y = op_desc.Input("Y")[0]; + if (parameters.count(Y)) { + PADDLE_ENFORCE_GT( + add_weight_op_set.count(op_type), + 0, + platform::errors::Unimplemented( + "Unsupported elementwise type %s", op_type.c_str())); + it = Registry::Global().Lookup("elementwise_" + + op_type + "_weight"); + PADDLE_ENFORCE_NOT_NULL( + it, + platform::errors::Unimplemented( + "no OpConverter for optype [%s]", op_desc.Type())); + } else { + PADDLE_ENFORCE_GT( + add_tensor_op_set.count(op_type), + 0, + platform::errors::Unimplemented( + "Unsupported elementwise type %s", op_type.c_str())); + it = Registry::Global().Lookup("elementwise_" + + op_type + "_tensor"); + } + PADDLE_ENFORCE_NOT_NULL( + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); + } - if (op_desc.Type() == "depthwise_conv2d") { - it = Registry::Global().Lookup("conv2d"); - PADDLE_ENFORCE_NOT_NULL( - it, - platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); - } - if (op_desc.Type() == "depthwise_conv2d_transpose") { - it = Registry::Global().Lookup("conv2d_transpose"); - PADDLE_ENFORCE_NOT_NULL( - it, - platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); - } - if (op_desc.Type() == "transpose2") { - it = Registry::Global().Lookup("transpose"); - PADDLE_ENFORCE_NOT_NULL( - it, - platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); - } - if (op_desc.Type() == "flatten2") { - it = Registry::Global().Lookup("flatten"); - PADDLE_ENFORCE_NOT_NULL( - it, - platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); - } - // reshape2 == reshape - if (op_desc.Type() == "reshape2") { - it = Registry::Global().Lookup("reshape"); - PADDLE_ENFORCE_NOT_NULL( - it, - platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); - } - if (!it) { - it = Registry::Global().Lookup(op_desc.Type()); + if (op_desc.Type() == "depthwise_conv2d") { + it = Registry::Global().Lookup("conv2d"); + PADDLE_ENFORCE_NOT_NULL( + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); + } + if (op_desc.Type() == "depthwise_conv2d_transpose") { + it = Registry::Global().Lookup("conv2d_transpose"); + PADDLE_ENFORCE_NOT_NULL( + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); + } + if (op_desc.Type() == "transpose2") { + it = Registry::Global().Lookup("transpose"); + PADDLE_ENFORCE_NOT_NULL( + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); + } + if (op_desc.Type() == "flatten2") { + it = Registry::Global().Lookup("flatten"); + PADDLE_ENFORCE_NOT_NULL( + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); + } + // reshape2 == reshape + if (op_desc.Type() == "reshape2") { + it = Registry::Global().Lookup("reshape"); + PADDLE_ENFORCE_NOT_NULL( + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); + } + if (!it) { + it = Registry::Global().Lookup(op_desc.Type()); + } + break; + + case OpConverterType::GenericPluginCreater: + LOG(INFO) << "There is no OpConverter for type " << op_desc.Type() + << ", now use generic_plugin_creater!"; + it = Registry::Global().Lookup("generic_plugin_creater"); + break; + + case OpConverterType::CustomPluginCreater: + LOG(INFO) << "There is no OpConverter for type " << op_desc.Type() + << ", now use custom_plugin_creater!"; + it = Registry::Global().Lookup("custom_plugin_creater"); + break; + + default: + CHECK(false) << "no OpConverter for optype " << op_desc.Type(); } + PADDLE_ENFORCE_NOT_NULL( it, platform::errors::Unimplemented("no OpConverter for optype [%s]", op_desc.Type())); it->SetEngine(engine); + it->SetBlockDesc(block); (*it)(op, scope, test_mode); size_t output_num = op_desc.OutputNames().size(); @@ -257,7 +282,7 @@ class OpConverter { } for (int i = 0; i < block.ops_size(); i++) { const auto& op = block.ops(i); - ConvertOp(op, parameters, scope, engine); + ConvertOp(op, parameters, scope, engine, false, &block); } for (int i = 0; i < engine->network()->getNbLayers(); i++) { auto layer = engine->network()->getLayer(i); @@ -620,10 +645,16 @@ class OpConverter { } void SetEngine(TensorRTEngine* engine) { engine_ = engine; } + void SetBlockDesc(const framework::proto::BlockDesc* block) { + block_ = block; + } + virtual ~OpConverter() {} // TensorRT engine TensorRTEngine* engine_{nullptr}; + // BlockDesc + const framework::proto::BlockDesc* block_{nullptr}; protected: bool test_mode_; diff --git a/paddle/fluid/inference/tensorrt/convert/test_custom_op_plugin.h b/paddle/fluid/inference/tensorrt/convert/test_custom_op_plugin.h new file mode 100644 index 00000000000..adb41528bae --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_custom_op_plugin.h @@ -0,0 +1,356 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class custom_op_plugin : public nvinfer1::IPluginV2 { + public: + explicit custom_op_plugin(float float_attr) { float_attr_ = float_attr; } + + custom_op_plugin(const void* buffer, size_t length) { + DeserializeValue(&buffer, &length, &float_attr_); + } + + size_t getSerializationSize() const noexcept override { + return SerializedSize(float_attr_); + } + + void serialize(void* buffer) const noexcept override { + SerializeValue(&buffer, float_attr_); + } + + nvinfer1::IPluginV2* clone() const noexcept override { + return new custom_op_plugin(float_attr_); + } + + ~custom_op_plugin() override = default; + + const char* getPluginType() const noexcept override { + return "custom_op_paddle_trt_plugin"; + } + + const char* getPluginVersion() const noexcept override { return "1"; } + + int getNbOutputs() const noexcept override { return 1; } + + nvinfer1::Dims getOutputDimensions(int index, + const nvinfer1::Dims* inputs, + int nbInputDims) noexcept override { + return inputs[0]; + } + + bool supportsFormat(nvinfer1::DataType type, + nvinfer1::PluginFormat format) const noexcept override { + return true; + } + + void configureWithFormat(nvinfer1::Dims const* inputDims, + int32_t nbInputs, + nvinfer1::Dims const* outputDims, + int32_t nbOutputs, + nvinfer1::DataType type, + nvinfer1::PluginFormat format, + int32_t maxBatchSize) noexcept override {} + + int initialize() noexcept override { return 0; } + + void terminate() noexcept override {} + + size_t getWorkspaceSize(int maxBatchSize) const noexcept override { + return 0; + } + +#if IS_TRT_VERSION_LT(8000) + int enqueue(int batch_size, + const void* const* inputs, + void** outputs, +#else + int enqueue(int batch_size, + const void* const* inputs, + void* const* outputs, +#endif + void* workspace, + cudaStream_t stream) noexcept override { + return 0; + } + + void destroy() noexcept override { delete this; } + + void setPluginNamespace(const char* libNamespace) noexcept override { + namespace_ = libNamespace; + } + + const char* getPluginNamespace() const noexcept override { + return namespace_.c_str(); + } + + private: + float float_attr_; + std::string namespace_; +}; + +class custom_op_plugin_creator : public nvinfer1::IPluginCreator { + public: + custom_op_plugin_creator() {} + + ~custom_op_plugin_creator() override = default; + + const char* getPluginName() const noexcept override { + return "custom_op_paddle_trt_plugin"; + } + + const char* getPluginVersion() const noexcept override { return "1"; } + + void setPluginNamespace(const char* pluginNamespace) noexcept override { + plugin_namespace_ = pluginNamespace; + } + + const char* getPluginNamespace() const noexcept override { + return plugin_namespace_.c_str(); + } + + const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override { + return nullptr; + } + + nvinfer1::IPluginV2* createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) noexcept override { + CHECK_EQ(fc->nbFields, 7); + // float_attr + auto attr_field = (fc->fields)[0]; + CHECK(attr_field.type == nvinfer1::PluginFieldType::kFLOAT32); + CHECK_EQ(attr_field.length, 1); + float float_value = (reinterpret_cast(attr_field.data))[0]; + CHECK_EQ(float_value, 1.0); + + // int_attr + attr_field = (fc->fields)[1]; + CHECK(attr_field.type == nvinfer1::PluginFieldType::kINT32); + CHECK_EQ(attr_field.length, 1); + int int_value = (reinterpret_cast(attr_field.data))[0]; + CHECK_EQ(int_value, 1); + + // bool_attr + attr_field = (fc->fields)[2]; + CHECK(attr_field.type == nvinfer1::PluginFieldType::kINT32); + CHECK_EQ(attr_field.length, 1); + int bool_value = (reinterpret_cast(attr_field.data))[0]; + CHECK_EQ(bool_value, 1); + + // string_attr + attr_field = (fc->fields)[3]; + CHECK(attr_field.type == nvinfer1::PluginFieldType::kCHAR); + std::string expect_string_attr = "test_string_attr"; + CHECK_EQ((size_t)attr_field.length, expect_string_attr.size() + 1); + const char* receive_string_attr = + reinterpret_cast(attr_field.data); + CHECK(expect_string_attr == std::string(receive_string_attr)); + + // ints_attr + attr_field = (fc->fields)[4]; + CHECK(attr_field.type == nvinfer1::PluginFieldType::kINT32); + CHECK_EQ(attr_field.length, 3); + const int* ints_value = reinterpret_cast(attr_field.data); + CHECK_EQ(ints_value[0], 1); + CHECK_EQ(ints_value[1], 2); + CHECK_EQ(ints_value[2], 3); + + // floats_attr + attr_field = (fc->fields)[5]; + CHECK(attr_field.type == nvinfer1::PluginFieldType::kFLOAT32); + CHECK_EQ(attr_field.length, 3); + const float* floats_value = reinterpret_cast(attr_field.data); + CHECK_EQ(floats_value[0], 1.0); + CHECK_EQ(floats_value[1], 2.0); + CHECK_EQ(floats_value[2], 3.0); + + // bools_attr + attr_field = (fc->fields)[6]; + CHECK(attr_field.type == nvinfer1::PluginFieldType::kINT32); + CHECK_EQ(attr_field.length, 3); + ints_value = reinterpret_cast(attr_field.data); + CHECK_EQ(ints_value[0], true); + CHECK_EQ(ints_value[1], false); + CHECK_EQ(ints_value[2], true); + + return new custom_op_plugin(float_value); + } + + nvinfer1::IPluginV2* deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) noexcept override { + return new custom_op_plugin(serialData, serialLength); + } + + private: + std::string plugin_namespace_; +}; + +class custom_op_dynamic_plugin : public nvinfer1::IPluginV2DynamicExt { + public: + explicit custom_op_dynamic_plugin(float float_attr) + : float_attr_(float_attr) {} + + custom_op_dynamic_plugin(const void* buffer, size_t length) { + DeserializeValue(&buffer, &length, &float_attr_); + } + + ~custom_op_dynamic_plugin() override = default; + + const char* getPluginType() const noexcept override { + return "custom_op_paddle_trt_dynamic_plugin"; + } + + const char* getPluginVersion() const noexcept override { return "1"; } + + int getNbOutputs() const noexcept override { return 1; } + + int initialize() noexcept override { return 0; } + + void terminate() noexcept override {} + + size_t getSerializationSize() const noexcept override { + return SerializedSize(float_attr_); + } + + void serialize(void* buffer) const noexcept override { + SerializeValue(&buffer, float_attr_); + } + + void destroy() noexcept override { delete this; } + + void setPluginNamespace(const char* libNamespace) noexcept override { + namespace_ = libNamespace; + } + + const char* getPluginNamespace() const noexcept override { + return namespace_.c_str(); + } + + /*IPluginV2Ext method*/ + nvinfer1::DataType getOutputDataType( + int32_t index, + nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept override { + return inputTypes[index]; + } + + /*IPluginV2DynamicExt method*/ + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override { + return new custom_op_dynamic_plugin(float_attr_); + }; + + nvinfer1::DimsExprs getOutputDimensions( + int32_t outputIndex, + const nvinfer1::DimsExprs* inputs, + int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override { + return inputs[0]; + } + + bool supportsFormatCombination(int32_t pos, + const nvinfer1::PluginTensorDesc* inOut, + int32_t nbInputs, + int32_t nbOutputs) noexcept override { + return true; + } + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int32_t nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int32_t nbOutputs) noexcept override {} + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int32_t nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int32_t nbOutputs) const noexcept override { + return 0; + } + + int32_t enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept override { + return 0; + } + + private: + float float_attr_ = 0; + std::string namespace_; +}; + +class custom_op_dynamic_plugin_creator : public nvinfer1::IPluginCreator { + public: + custom_op_dynamic_plugin_creator() {} + + ~custom_op_dynamic_plugin_creator() override = default; + + const char* getPluginName() const noexcept override { + return "custom_op_paddle_trt_dynamic_plugin"; + } + + const char* getPluginVersion() const noexcept override { return "1"; } + + void setPluginNamespace(char const* pluginNamespace) noexcept override { + plugin_namespace_ = pluginNamespace; + } + + const char* getPluginNamespace() const noexcept override { + return plugin_namespace_.c_str(); + } + + const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override { + return nullptr; + } + + nvinfer1::IPluginV2* createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) noexcept override { + return new custom_op_dynamic_plugin(1.0); + } + + nvinfer1::IPluginV2* deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) noexcept override { + return new custom_op_dynamic_plugin(serialData, serialLength); + } + + private: + std::string plugin_namespace_; +}; + +REGISTER_TRT_PLUGIN_V2(custom_op_plugin_creator); +REGISTER_TRT_PLUGIN_V2(custom_op_dynamic_plugin_creator); + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/convert/test_custom_plugin_creater.cc b/paddle/fluid/inference/tensorrt/convert/test_custom_plugin_creater.cc new file mode 100644 index 00000000000..2a3ead9c8e6 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_custom_plugin_creater.cc @@ -0,0 +1,209 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include // NOLINT + +#include "paddle/extension.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/test_custom_op_plugin.h" + +PD_BUILD_OP(custom_op) + .Inputs({"Input"}) + .Outputs({"Output"}) + .Attrs({ + "float_attr", + "int_attr", + "bool_attr", + "string_attr", + "ints_attr", + "floats_attr", + "bools_attr", + }); + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(CustomPluginCreater, StaticShapePlugin) { + framework::ProgramDesc prog; + auto *block = prog.MutableBlock(0); + auto *op = block->AppendOp(); + framework::proto::OpDesc *op_desc = op->Proto(); + + op_desc->set_type("custom_op"); + auto *input_var = op_desc->add_inputs(); + input_var->set_parameter("Input"); + *input_var->add_arguments() = "X"; + + auto *output_var = op_desc->add_outputs(); + output_var->set_parameter("Output"); + *output_var->add_arguments() = "Out"; + + auto *attr = op_desc->add_attrs(); + attr->set_name("float_attr"); + attr->set_type(paddle::framework::proto::AttrType::FLOAT); + attr->set_f(1.0); + + attr = op_desc->add_attrs(); + attr->set_name("int_attr"); + attr->set_type(paddle::framework::proto::AttrType::INT); + attr->set_i(1); + + attr = op_desc->add_attrs(); + attr->set_name("bool_attr"); + attr->set_type(paddle::framework::proto::AttrType::BOOLEAN); + attr->set_b(true); + + attr = op_desc->add_attrs(); + attr->set_name("string_attr"); + attr->set_type(paddle::framework::proto::AttrType::STRING); + attr->set_s("test_string_attr"); + + attr = op_desc->add_attrs(); + attr->set_name("ints_attr"); + attr->set_type(paddle::framework::proto::AttrType::INTS); + attr->add_ints(1); + attr->add_ints(2); + attr->add_ints(3); + + attr = op_desc->add_attrs(); + attr->set_name("floats_attr"); + attr->set_type(paddle::framework::proto::AttrType::FLOATS); + attr->add_floats(1.0); + attr->add_floats(2.0); + attr->add_floats(3.0); + + attr = op_desc->add_attrs(); + attr->set_name("bools_attr"); + attr->set_type(paddle::framework::proto::AttrType::BOOLEANS); + attr->add_bools(true); + attr->add_bools(false); + attr->add_bools(true); + + // init trt engine + std::unique_ptr engine_; + engine_.reset(new TensorRTEngine(5, 1 << 15)); + engine_->InitNetwork(); + + engine_->DeclareInput( + "X", nvinfer1::DataType::kFLOAT, nvinfer1::Dims3(2, 5, 5)); + + framework::Scope scope; + + tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt(); + + auto &custom_plugin_tell = OpTeller::Global().GetCustomPluginTeller(); + + framework::OpDesc custom_op(*op_desc, nullptr); + CHECK_EQ((*custom_plugin_tell)(custom_op, false, false), true); + + OpTeller::Global().SetOpConverterType("custom_op", + OpConverterType::CustomPluginCreater); + + OpConverter converter; + converter.ConvertBlock( + *block->Proto(), {}, scope, engine_.get() /*TensorRTEngine*/); +} + +TEST(CustomPluginCreater, DynamicShapePlugin) { + framework::ProgramDesc prog; + auto *block = prog.MutableBlock(0); + auto *op = block->AppendOp(); + framework::proto::OpDesc *op_desc = op->Proto(); + + op_desc->set_type("custom_op"); + auto *input_var = op_desc->add_inputs(); + input_var->set_parameter("Input"); + *input_var->add_arguments() = "X"; + + auto *output_var = op_desc->add_outputs(); + output_var->set_parameter("Output"); + *output_var->add_arguments() = "Out"; + + auto *attr = op_desc->add_attrs(); + attr->set_name("float_attr"); + attr->set_type(paddle::framework::proto::AttrType::FLOAT); + + attr = op_desc->add_attrs(); + attr->set_name("int_attr"); + attr->set_type(paddle::framework::proto::AttrType::INT); + + attr = op_desc->add_attrs(); + attr->set_name("bool_attr"); + attr->set_type(paddle::framework::proto::AttrType::BOOLEAN); + + attr = op_desc->add_attrs(); + attr->set_name("string_attr"); + attr->set_type(paddle::framework::proto::AttrType::STRING); + + attr = op_desc->add_attrs(); + attr->set_name("ints_attr"); + attr->set_type(paddle::framework::proto::AttrType::INTS); + + attr = op_desc->add_attrs(); + attr->set_name("floats_attr"); + attr->set_type(paddle::framework::proto::AttrType::FLOATS); + + attr = op_desc->add_attrs(); + attr->set_name("bools_attr"); + attr->set_type(paddle::framework::proto::AttrType::BOOLEANS); + + // init trt engine + std::unique_ptr engine_; + + std::map> min_input_shape = { + {"x", {1, 2, 5, 5}}}; + + std::map> max_input_shape = { + {"x", {1, 2, 5, 5}}}; + + std::map> optim_input_shape = { + {"x", {1, 2, 5, 5}}}; + + engine_.reset(new TensorRTEngine(5, + 1 << 15, + AnalysisConfig::Precision::kFloat32, + nullptr, + 0, + min_input_shape, + max_input_shape, + optim_input_shape)); + engine_->InitNetwork(); + + LOG(INFO) << "with_dynamic_shape " << engine_->with_dynamic_shape(); + engine_->DeclareInput( + "X", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4(-1, 2, 5, 5)); + + framework::Scope scope; + + tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt(); + + auto &custom_plugin_tell = OpTeller::Global().GetCustomPluginTeller(); + + framework::OpDesc custom_op(*op_desc, nullptr); + CHECK_EQ((*custom_plugin_tell)(custom_op, false, true), true); + + OpTeller::Global().SetOpConverterType("custom_op", + OpConverterType::CustomPluginCreater); + + OpConverter converter; + converter.ConvertBlock( + *block->Proto(), {}, scope, engine_.get() /*TensorRTEngine*/); +} +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +USE_TRT_CONVERTER(custom_plugin_creater) diff --git a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc index 5e748aad237..795f62a3e1e 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc @@ -57,6 +57,7 @@ TEST(OpConverter, ConvertBlock) { x_tensor->Resize(phi::make_ddim(dim_vec)); x_tensor->mutable_data(platform::CUDAPlace(0)); + OpTeller::Global().SetOpConverterType("conv2d", OpConverterType::Default); OpConverter converter; converter.ConvertBlock( *block->Proto(), {"conv2d-Y"}, scope, engine_.get() /*TensorRTEngine*/); diff --git a/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc b/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc new file mode 100644 index 00000000000..1d75f0a7fbf --- /dev/null +++ b/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_factory.h" +#include "paddle/fluid/inference/tensorrt/helper.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/kernels/funcs/unfold_functor.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +nvinfer1::DimsExprs GatherNdInferMeta( + int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder, // NOLINT + const framework::OpDesc& op_desc) { + const nvinfer1::DimsExprs x_dims = inputs[0]; + const int x_dims_size = inputs[0].nbDims; + const nvinfer1::DimsExprs index_dims = inputs[1]; + const int index_dims_size = inputs[1].nbDims; + + std::vector result_dims; + // The result dims is + // Index.shape[:-1] + X.shape[Index.shape[-1]:] + for (int i = 0; i < index_dims_size - 1; ++i) { + result_dims.emplace_back(index_dims.d[i]); + } + + if (index_dims.d[index_dims_size - 1]->isConstant()) { + for (int i = index_dims.d[index_dims_size - 1]->getConstantValue(); + i < x_dims_size; + ++i) { + result_dims.emplace_back(x_dims.d[i]); + } + } + + nvinfer1::DimsExprs output; + output.nbDims = result_dims.size(); + for (int i = 0; i < output.nbDims; i++) { + output.d[i] = result_dims[i]; + } + return output; +} +PD_REGISTER_DYNAMIC_INFER_META_FN(gather_nd, GatherNdInferMeta); +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_factory.h b/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_factory.h new file mode 100644 index 00000000000..0196d81754f --- /dev/null +++ b/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_factory.h @@ -0,0 +1,99 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/platform/macros.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/macros.h" +#include "paddle/utils/flat_hash_map.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +using DynamicMetaFn = + nvinfer1::DimsExprs (*)(int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder, // NOLINT + const framework::OpDesc& op_desc); + +class DynamicMetaFnFactory { + public: + static DynamicMetaFnFactory& Instance() { + static DynamicMetaFnFactory g_meta_fn_map; + return g_meta_fn_map; + } + + bool Contains(const std::string& op_name) const { + return meta_fn_map_.count(op_name) > 0; + } + + void Insert(std::string op_name, DynamicMetaFn infer_meta_fn) { + PADDLE_ENFORCE_NE( + Contains(op_name), + true, + phi::errors::AlreadyExists( + "`%s` op's DynamicInferMetaFn has been registered.", op_name)); + meta_fn_map_.insert({std::move(op_name), std::move(infer_meta_fn)}); + } + + const DynamicMetaFn& Get(const std::string& op_name) const { + auto it = meta_fn_map_.find(op_name); + PADDLE_ENFORCE_NE( + it, + meta_fn_map_.end(), + phi::errors::NotFound( + "`%s` op's DynamicInferMetaFn has been registered.", op_name)); + return it->second; + } + + private: + DynamicMetaFnFactory() = default; + + paddle::flat_hash_map meta_fn_map_; + + DISABLE_COPY_AND_ASSIGN(DynamicMetaFnFactory); +}; + +struct DynamicMetaFnRegistrar { + DynamicMetaFnRegistrar(const char* op_name, DynamicMetaFn infer_meta_fn) { + DynamicMetaFnFactory::Instance().Insert(op_name, std::move(infer_meta_fn)); + } + + static void Touch() {} +}; + +#define PD_REGISTER_DYNAMIC_INFER_META_FN(op_name, dynamic_infer_meta_fn) \ + static paddle::inference::tensorrt::DynamicMetaFnRegistrar \ + registrar_dynamic_infer_meta_fn_for_##op_name(#op_name, \ + dynamic_infer_meta_fn); \ + int TouchDynamicMetaFnRegistrar_##op_name() { \ + registrar_dynamic_infer_meta_fn_for_##op_name.Touch(); \ + return 0; \ + } + +#define USE_TRT_DYNAMIC_INFER_META_FN(op_name) \ + extern int TouchDynamicMetaFnRegistrar_##op_name(); \ + static int use_op_dynamic_infer_meta##op_name UNUSED = \ + TouchDynamicMetaFnRegistrar_##op_name(); + +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_registry.h b/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_registry.h new file mode 100644 index 00000000000..f31040772c9 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_registry.h @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_factory.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +USE_TRT_DYNAMIC_INFER_META_FN(gather_nd); +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 70f290db034..55457aa5827 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -18,6 +18,11 @@ #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/framework/op_meta_info_helper.h" +#include "paddle/fluid/framework/phi_utils.h" +#include "paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_factory.h" +#include "paddle/phi/core/compat/op_utils.h" +#include "paddle/phi/core/kernel_factory.h" namespace paddle { namespace framework { @@ -60,252 +65,16 @@ struct SimpleOpTypeSetTeller : public Teller { #endif } - bool operator()(const std::string& op_type, - const framework::OpDesc& desc, - bool use_no_calib_int8) override { - if (use_no_calib_int8) { - return int8_teller_set.count(op_type); - } else { - return teller_set.count(op_type); - } - } - - private: - // use this set for no calib int8. - std::unordered_set int8_teller_set{ - "mul", - "matmul", - "conv2d", - "conv2d_fusion", - "pool2d", - "relu", - "elu", - "selu", - "softsign", - "softplus", - "stanh", - "thresholded_relu", - "exp", - "log", - "sqrt", - "abs", - "sin", - "cos", - "tan", - "sinh", - "cosh", - "asin", - "acos", - "atan", - "asinh", - "atanh", - "ceil", - "floor", - "erf", - "softmax", - "sigmoid", - "hard_swish", - "depthwise_conv2d", - "batch_norm", - "concat", - "tanh", - "pad", - "elementwise_add", - "elementwise_sub", - "elementwise_mul", - "elementwise_div", - "elementwise_pow", - "equal", - "dropout", - "prelu", - "conv2d_transpose", - "depthwise_conv2d_transpose", - "leaky_relu", - "fc", - "shuffle_channel", - "swish", - "silu", - "split", - "instance_norm", - "gelu", - "layer_norm", - "scale", - "stack", - "transpose2", - "transpose", - "top_k", - "top_k_v2", - "flatten2", - "flatten", - "gather", - "gather_nd", - "yolo_box", - "yolo_box_head", - "arg_max", - "roi_align", - "affine_channel", - "nearest_interp", - "anchor_generator", - "reduce_sum", - "reduce_mean", - "conv3d", - "conv3d_transpose", - "mish", - "nearest_interp_v2", - "bilinear_interp_v2", - "pool3d", - "deformable_conv", - "relu6", - "hard_sigmoid", - "clip", - "fused_embedding_eltwise_layernorm", - "multihead_matmul", - "skip_layernorm", - "slice", - "strided_slice", - "fused_preln_embedding_eltwise_layernorm", - "preln_residual_bias", - "c_allreduce_sum", - "c_allreduce_min", - "c_allreduce_max", - "c_allreduce_prod", - "roll", - "cast", - "preln_skip_layernorm", - "transformer_input_convert", - "recover_padding", - "remove_padding", - "fill_constant", - "sum", - "shape", - "squeeze2", - "unsqueeze2", - "layernorm_shift_partition"}; - std::unordered_set teller_set{ - "mul", - "matmul", - "conv2d", - "conv2d_fusion", - "pool2d", - "relu", - "elu", - "selu", - "softsign", - "softplus", - "stanh", - "thresholded_relu", - "exp", - "log", - "sqrt", - "abs", - "sin", - "cos", - "tan", - "sinh", - "cosh", - "asin", - "acos", - "atan", - "asinh", - "atanh", - "ceil", - "floor", - "erf", - "softmax", - "sigmoid", - "hard_swish", - "depthwise_conv2d", - "batch_norm", - "concat", - "tanh", - "pad", - "elementwise_add", - "elementwise_sub", - "elementwise_mul", - "elementwise_div", - "elementwise_pow", - "equal", - "dropout", - "prelu", - "conv2d_transpose", - "depthwise_conv2d_transpose", - "leaky_relu", - "fc", - "shuffle_channel", - "swish", - "silu", - "split", - "instance_norm", - "gelu", - "layer_norm", - "scale", - "stack", - "transpose2", - "transpose", - "top_k", - "top_k_v2", - "flatten2", - "flatten", - "gather", - "gather_nd", - "yolo_box", - "yolo_box_head", - "arg_max", - "roi_align", - "affine_channel", - "nearest_interp", - "anchor_generator", - "reduce_sum", - "reduce_mean", - "conv3d", - "conv3d_transpose", - "mish", - "bilinear_interp_v2", - "nearest_interp_v2", - "pool3d", - "deformable_conv", - "relu6", - "hard_sigmoid", - "clip", - "fused_embedding_eltwise_layernorm", - "multihead_matmul", - "skip_layernorm", - "slice", - "strided_slice", - "fused_preln_embedding_eltwise_layernorm", - "preln_skip_layernorm", - "preln_residual_bias", - "c_allreduce_sum", - "c_allreduce_min", - "c_allreduce_max", - "c_allreduce_prod", - "roll", - "cast", - "transformer_input_convert", - "recover_padding", - "remove_padding", - "fill_constant", - "sum", - "shape", - "squeeze2", - "unsqueeze2", - "fused_token_prune", - "layernorm_shift_partition"}; -}; - -bool OpTeller::Tell(const framework::ir::Node* node, - bool use_no_calib_int8, - bool with_dynamic_shape) { - const std::string op_type = node->Op()->Type(); - const framework::OpDesc desc = *node->Op(); - // do not support the op which is labeled the `skip_quant` - if ((desc.HasAttr("namescope") && - PADDLE_GET_CONST(std::string, desc.GetAttr("op_namescope")) == - "/skip_quant_2/") || - desc.HasAttr("skip_quant")) - return false; - - for (auto& teller : tellers_) { + bool operator()(const framework::OpDesc& desc, + bool use_no_calib_int8 = false, + bool with_dynamic_shape = false) override { + const std::string op_type = desc.Type(); + // do not support the op which is labeled the `skip_quant` + if ((desc.HasAttr("namescope") && + PADDLE_GET_CONST(std::string, desc.GetAttr("op_namescope")) == + "/skip_quant_2/") || + desc.HasAttr("skip_quant")) + return false; std::unordered_set act_op_list = { "relu", "relu6", "sigmoid", "elu", "selu", "softsign", @@ -2300,13 +2069,329 @@ bool OpTeller::Tell(const framework::ir::Node* node, } } - if ((*teller)(op_type, desc, use_no_calib_int8)) return true; + if (use_no_calib_int8) { + return int8_teller_set.count(op_type); + } else { + return teller_set.count(op_type); + } } + private: + // use this set for no calib int8. + std::unordered_set int8_teller_set{ + "mul", + "matmul", + "conv2d", + "conv2d_fusion", + "pool2d", + "relu", + "elu", + "selu", + "softsign", + "softplus", + "stanh", + "thresholded_relu", + "exp", + "log", + "sqrt", + "abs", + "sin", + "cos", + "tan", + "sinh", + "cosh", + "asin", + "acos", + "atan", + "asinh", + "atanh", + "ceil", + "floor", + "erf", + "softmax", + "sigmoid", + "hard_swish", + "depthwise_conv2d", + "batch_norm", + "concat", + "tanh", + "pad", + "elementwise_add", + "elementwise_sub", + "elementwise_mul", + "elementwise_div", + "elementwise_pow", + "equal", + "dropout", + "prelu", + "conv2d_transpose", + "depthwise_conv2d_transpose", + "leaky_relu", + "fc", + "shuffle_channel", + "swish", + "silu", + "split", + "instance_norm", + "gelu", + "layer_norm", + "scale", + "stack", + "transpose2", + "transpose", + "top_k", + "top_k_v2", + "flatten2", + "flatten", + "gather", + "gather_nd", + "yolo_box", + "yolo_box_head", + "arg_max", + "roi_align", + "affine_channel", + "nearest_interp", + "anchor_generator", + "reduce_sum", + "reduce_mean", + "conv3d", + "conv3d_transpose", + "mish", + "nearest_interp_v2", + "bilinear_interp_v2", + "pool3d", + "deformable_conv", + "relu6", + "hard_sigmoid", + "clip", + "fused_embedding_eltwise_layernorm", + "multihead_matmul", + "skip_layernorm", + "slice", + "strided_slice", + "fused_preln_embedding_eltwise_layernorm", + "preln_residual_bias", + "c_allreduce_sum", + "c_allreduce_min", + "c_allreduce_max", + "c_allreduce_prod", + "roll", + "cast", + "preln_skip_layernorm", + "transformer_input_convert", + "recover_padding", + "remove_padding", + "fill_constant", + "sum", + "shape", + "squeeze2", + "unsqueeze2", + "layernorm_shift_partition"}; + std::unordered_set teller_set{ + "mul", + "matmul", + "conv2d", + "conv2d_fusion", + "pool2d", + "relu", + "elu", + "selu", + "softsign", + "softplus", + "stanh", + "thresholded_relu", + "exp", + "log", + "sqrt", + "abs", + "sin", + "cos", + "tan", + "sinh", + "cosh", + "asin", + "acos", + "atan", + "asinh", + "atanh", + "ceil", + "floor", + "erf", + "softmax", + "sigmoid", + "hard_swish", + "depthwise_conv2d", + "batch_norm", + "concat", + "tanh", + "pad", + "elementwise_add", + "elementwise_sub", + "elementwise_mul", + "elementwise_div", + "elementwise_pow", + "equal", + "dropout", + "prelu", + "conv2d_transpose", + "depthwise_conv2d_transpose", + "leaky_relu", + "fc", + "shuffle_channel", + "swish", + "silu", + "split", + "instance_norm", + "gelu", + "layer_norm", + "scale", + "stack", + "transpose2", + "transpose", + "top_k", + "top_k_v2", + "flatten2", + "flatten", + "gather", + "gather_nd", + "yolo_box", + "yolo_box_head", + "arg_max", + "roi_align", + "affine_channel", + "nearest_interp", + "anchor_generator", + "reduce_sum", + "reduce_mean", + "conv3d", + "conv3d_transpose", + "mish", + "bilinear_interp_v2", + "nearest_interp_v2", + "pool3d", + "deformable_conv", + "relu6", + "hard_sigmoid", + "clip", + "fused_embedding_eltwise_layernorm", + "multihead_matmul", + "skip_layernorm", + "slice", + "strided_slice", + "fused_preln_embedding_eltwise_layernorm", + "preln_skip_layernorm", + "preln_residual_bias", + "c_allreduce_sum", + "c_allreduce_min", + "c_allreduce_max", + "c_allreduce_prod", + "roll", + "cast", + "transformer_input_convert", + "recover_padding", + "remove_padding", + "fill_constant", + "sum", + "shape", + "squeeze2", + "unsqueeze2", + "fused_token_prune", + "layernorm_shift_partition"}; +}; + +struct GenericPluginTeller : public Teller { + public: + GenericPluginTeller() {} + bool operator()(const framework::OpDesc& desc, + bool use_no_calib_int8 = false, + bool with_dynamic_shape = false) override { + const std::string op_type = desc.Type(); + // only consider dynamic_shape mode + if (!with_dynamic_shape) { + return false; + } + + if (use_no_calib_int8) { + return false; + } else { + framework::InitDefaultKernelSignatureMap(); + bool res = phi::OpUtilsMap::Instance().HasArgumentMappingFn(op_type) || + phi::DefaultKernelSignatureMap::Instance().Has(op_type); + if (!res) { + VLOG(3) << op_type << " has no KernelSignature"; + return false; + } + res = phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type); + if (!res) { + VLOG(3) << op_type << " has no CompatiblePhiKernel in phi."; + return false; + } + auto& dynamic_infermeta_factory = + tensorrt::DynamicMetaFnFactory::Instance(); + res = dynamic_infermeta_factory.Contains(op_type); + if (!res) { + VLOG(3) << op_type << " has no DynamicMetaFn."; + return false; + } + return true; + } + } +}; + +struct CustomPluginTeller : public Teller { + public: + CustomPluginTeller() {} + bool operator()(const framework::OpDesc& desc, + bool use_no_calib_int8 = false, + bool with_dynamic_shape = false) override { + const std::string op_type = desc.Type(); + std::string expect_plugin_name; + + if (with_dynamic_shape) { + expect_plugin_name = op_type + "_paddle_trt_dynamic_plugin"; + } else { + expect_plugin_name = op_type + "_paddle_trt_plugin"; + } + + int num = 0; + auto creators = GetPluginRegistry()->getPluginCreatorList(&num); + + for (int i = 0; i < num; i++) { + if (std::string(creators[i]->getPluginName()) == expect_plugin_name) + return true; + } + return false; + } +}; + +bool OpTeller::Tell(const framework::ir::Node* node, + bool use_no_calib_int8, + bool with_dynamic_shape) { + const std::string op_type = node->Op()->Type(); + const framework::OpDesc desc = *node->Op(); + auto& default_teller = GetDefaultTeller(); + if ((*default_teller)(desc, use_no_calib_int8, with_dynamic_shape)) { + SetOpConverterType(op_type, OpConverterType::Default); + return true; + } + auto& generic_plugin_teller = GetGenericPluginTeller(); + if ((*generic_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) { + SetOpConverterType(op_type, OpConverterType::GenericPluginCreater); + return true; + } + auto& custom_plugin_teller = GetCustomPluginTeller(); + if ((*custom_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) { + SetOpConverterType(op_type, OpConverterType::CustomPluginCreater); + return true; + } return false; } -OpTeller::OpTeller() { tellers_.emplace_back(new SimpleOpTypeSetTeller); } +OpTeller::OpTeller() { + tellers_.emplace_back(new tensorrt::SimpleOpTypeSetTeller); + tellers_.emplace_back(new tensorrt::GenericPluginTeller); + tellers_.emplace_back(new tensorrt::CustomPluginTeller); +} } // namespace tensorrt } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/op_teller.h b/paddle/fluid/inference/tensorrt/op_teller.h index 1a6ce092a18..2fa3dc36121 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.h +++ b/paddle/fluid/inference/tensorrt/op_teller.h @@ -38,9 +38,9 @@ namespace tensorrt { * issues such as op_desc. */ struct Teller { - virtual bool operator()(const std::string& op_type, - const framework::OpDesc& desc, - bool use_no_calib_int8) = 0; + virtual bool operator()(const framework::OpDesc& desc, + bool use_no_calib_int8 = false, + bool with_dynamic_shape = false) = 0; virtual ~Teller() = default; }; @@ -55,9 +55,15 @@ struct Teller { *}; */ +enum class OpConverterType { + Default = 0, + GenericPluginCreater, + CustomPluginCreater +}; /* * class OpTeller helps to tell whether a fluid - * operator can be transformed to a TensorRT layer. + * operator can be transformed to a TensorRT layer + * and use which kind of OpConverter */ class OpTeller { public: @@ -70,11 +76,26 @@ class OpTeller { bool use_no_calib_int8 = false, bool with_dynamic_shape = false); + std::unique_ptr& GetDefaultTeller() { return tellers_.at(0); } + + std::unique_ptr& GetGenericPluginTeller() { return tellers_.at(1); } + + std::unique_ptr& GetCustomPluginTeller() { return tellers_.at(2); } + + void SetOpConverterType(std::string name, OpConverterType type) { + op_converter_type_map_[name] = type; + } + + const std::map& GetOpConverterTypeMap() const { + return op_converter_type_map_; + } + private: OpTeller(); private: std::vector> tellers_; + std::map op_converter_type_map_; }; } // namespace tensorrt diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index f602714f211..9fe02cd731d 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -32,7 +32,8 @@ list( c_allreduce_op_plugin.cu preln_residual_bias_plugin.cu fused_token_prune_op_plugin.cu - layernorm_shift_partition_op.cu) + layernorm_shift_partition_op.cu + generic_plugin.cu) if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) list(APPEND TRT_FILES spmm_plugin.cu) @@ -41,7 +42,13 @@ endif() nv_library( tensorrt_plugin SRCS ${TRT_FILES} - DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) + DEPS enforce + tensorrt_engine + prelu + tensor + bert_encoder_functor + tensorrt_dynamic_shape_infermeta_factory + tensorrt_plugin_arg_mapping_context) nv_test( test_split_plugin diff --git a/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu new file mode 100644 index 00000000000..2fc6e881e8e --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu @@ -0,0 +1,463 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/tensorrt/plugin/generic_plugin.h" +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/op_kernel_type.h" +#include "paddle/fluid/framework/phi_utils.h" +#include "paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_registry.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/compat/op_utils.h" +#include "paddle/phi/core/kernel_context.h" +#include "paddle/phi/core/kernel_factory.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc, + phi::KernelContext* kernel_context, + const phi::KernelSignature& signature, + const phi::Kernel& phi_kernel) { + const phi::KernelArgsDef& args_def = phi_kernel.args_def(); + const auto& attr_names = signature.attr_names; + const auto& attr_defs = args_def.attribute_defs(); + + PADDLE_ENFORCE_EQ( + attr_names.size(), + attr_defs.size(), + platform::errors::InvalidArgument( + "The attr_names.size() should be equal to attr_defs.size().")); + + framework::AttrReader attr_reader(op_desc.GetAttrMap()); + + for (size_t k = 0; k < attr_names.size(); ++k) { + auto attr_name = attr_names[k]; + auto* attr_ptr = attr_reader.GetAttr(attr_name); + if (attr_ptr) { + switch (attr_defs[k].type_index) { + case phi::AttributeType::SCALAR: { + auto& attr = *attr_ptr; + switch (AttrTypeID(attr)) { + case framework::proto::AttrType::FLOAT: + return kernel_context->EmplaceBackAttr( + phi::Scalar(PADDLE_GET_CONST(float, attr))); + break; + case framework::proto::AttrType::INT: + return kernel_context->EmplaceBackAttr( + phi::Scalar(PADDLE_GET_CONST(int, attr))); + break; + case framework::proto::AttrType::STRING: + return kernel_context->EmplaceBackAttr( + phi::Scalar(PADDLE_GET_CONST(std::string, attr))); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to Scalar when " + "ProtoAttr2PhiAttr.", + attr_name)); + } + } break; + + case phi::AttributeType::INT_ARRAY: { + auto& attr = *attr_ptr; + switch (AttrTypeID(attr)) { + case framework::proto::AttrType::INTS: + kernel_context->EmplaceBackAttr(std::move( + phi::IntArray(PADDLE_GET_CONST(std::vector, attr)))); + break; + case framework::proto::AttrType::LONGS: + kernel_context->EmplaceBackAttr(std::move( + phi::IntArray(PADDLE_GET_CONST(std::vector, attr)))); + break; + case framework::proto::AttrType::INT: + kernel_context->EmplaceBackAttr( + phi::IntArray({PADDLE_GET_CONST(int, attr)})); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to IntArray when " + "ProtoAttr2PhiAttr.", + attr_name)); + } + } break; + + case phi::AttributeType::SCALARS: { + auto& attr = *attr_ptr; + switch (AttrTypeID(attr)) { + case framework::proto::AttrType::INTS: { + const auto& vec = PADDLE_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } break; + case framework::proto::AttrType::LONGS: { + const auto& vec = PADDLE_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } break; + case framework::proto::AttrType::FLOATS: { + const auto& vec = PADDLE_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } break; + case framework::proto::AttrType::FLOAT64S: { + const auto& vec = PADDLE_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector when " + "ProtoAttr2PhiAttr.", + attr_name)); + } + } break; + + default: { + auto& attr = *attr_ptr; + switch (attr_defs[k].type_index) { + case phi::AttributeType::FLOAT32: + kernel_context->EmplaceBackAttr(PADDLE_GET_CONST(float, attr)); + break; + case phi::AttributeType::INT32: + kernel_context->EmplaceBackAttr(PADDLE_GET_CONST(int, attr)); + break; + case phi::AttributeType::BOOL: + kernel_context->EmplaceBackAttr(PADDLE_GET_CONST(bool, attr)); + break; + case phi::AttributeType::INT64: + kernel_context->EmplaceBackAttr(PADDLE_GET_CONST(int64_t, attr)); + break; + case phi::AttributeType::INT32S: + kernel_context->EmplaceBackAttr( + PADDLE_GET_CONST(std::vector, attr)); + break; + case phi::AttributeType::DATA_TYPE: { + auto data_type = paddle::framework::TransToPhiDataType( + static_cast( + PADDLE_GET_CONST(int, attr))); + kernel_context->EmplaceBackAttr(data_type); + } break; + case phi::AttributeType::STRING: + kernel_context->EmplaceBackAttr( + PADDLE_GET_CONST(std::string, attr)); + break; + case phi::AttributeType::INT64S: + switch (AttrTypeID(attr)) { + case framework::proto::AttrType::LONGS: + kernel_context->EmplaceBackAttr( + PADDLE_GET_CONST(std::vector, attr)); + break; + case framework::proto::AttrType::INTS: { + const auto& vector_int_attr = + PADDLE_GET_CONST(std::vector, attr); + const std::vector vector_int64_attr( + vector_int_attr.begin(), vector_int_attr.end()); + kernel_context->EmplaceBackAttr(vector_int64_attr); + } break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector " + "when ProtoAttr2PhiAttr.", + attr_name)); + } + break; + case phi::AttributeType::FLOAT32S: + kernel_context->EmplaceBackAttr( + PADDLE_GET_CONST(std::vector, attr)); + break; + case phi::AttributeType::STRINGS: + kernel_context->EmplaceBackAttr( + PADDLE_GET_CONST(std::vector, attr)); + break; + case phi::AttributeType::BOOLS: + kernel_context->EmplaceBackAttr( + PADDLE_GET_CONST(std::vector, attr)); + break; + case phi::AttributeType::FLOAT64S: + kernel_context->EmplaceBackAttr( + PADDLE_GET_CONST(std::vector, attr)); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` when construct " + "ProtoAttr2PhiAttr.", + attr_name)); + } + } + } + } + } +} + +GenericPlugin::GenericPlugin( + const paddle::framework::proto::OpDesc& proto_op_desc, + const InputOutPutVarInfo& in_out_info) { + proto_op_desc_ = proto_op_desc; + op_desc_ = std::move(framework::OpDesc(proto_op_desc_, nullptr)); + proto_op_desc_.SerializeToString(&op_meta_data_); + inputs_data_type_ = in_out_info.inputs_data_type; + outputs_data_type_ = in_out_info.outputs_data_type; +} + +GenericPlugin::GenericPlugin( + const paddle::framework::proto::OpDesc& proto_op_desc, + const std::vector& inputs_data_type, + const std::vector& outputs_data_type) { + proto_op_desc_ = proto_op_desc; + op_desc_ = std::move(framework::OpDesc(proto_op_desc_, nullptr)); + proto_op_desc_.SerializeToString(&op_meta_data_); + inputs_data_type_ = inputs_data_type; + outputs_data_type_ = outputs_data_type; +} + +GenericPlugin::GenericPlugin(void const* serial_data, size_t serial_length) { + DeserializeValue(&serial_data, &serial_length, &inputs_data_type_); + DeserializeValue(&serial_data, &serial_length, &outputs_data_type_); + std::string op_meta_data((char*)(serial_data), serial_length); // NOLINT + op_meta_data_ = std::move(op_meta_data); + proto_op_desc_.ParseFromString(op_meta_data_); + op_desc_ = std::move(framework::OpDesc(proto_op_desc_, nullptr)); +} + +int GenericPlugin::getNbOutputs() const TRT_NOEXCEPT { + int res = 0; + for (auto& i : op_desc_.Outputs()) { + if (!i.second.empty()) res += i.second.size(); + } + return res; +} + +int GenericPlugin::getNbInputs() const TRT_NOEXCEPT { + int res = 0; + for (auto& i : op_desc_.Inputs()) { + if (!i.second.empty()) res += i.second.size(); + } + return res; +} + +nvinfer1::IPluginV2DynamicExt* GenericPlugin::clone() const TRT_NOEXCEPT { + nvinfer1::IPluginV2DynamicExt* plugin = + new GenericPlugin(proto_op_desc_, inputs_data_type_, outputs_data_type_); + plugin->initialize(); + return plugin; +} + +void GenericPlugin::serialize(void* buffer) const TRT_NOEXCEPT { + // inputs_data_type_ + SerializeValue(&buffer, inputs_data_type_); + // outputs_data_type_ + SerializeValue(&buffer, outputs_data_type_); + // serialize op_meta_data_ + std::memcpy(buffer, op_meta_data_.c_str(), op_meta_data_.size()); + reinterpret_cast(buffer) += op_meta_data_.size(); +} + +bool GenericPlugin::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc* in_out, + int nb_inputs, + int nb_outputs) TRT_NOEXCEPT { + return true; +} + +nvinfer1::DataType GenericPlugin::getOutputDataType( + int index, + const nvinfer1::DataType* input_types, + int nb_inputs) const TRT_NOEXCEPT { + return input_types[0]; +} + +int GenericPlugin::initialize() TRT_NOEXCEPT { + std::string op_type = op_desc_.Type(); + + phi::KernelSignature phi_kernel_signature; + if (phi::OpUtilsMap::Instance().HasArgumentMappingFn(op_type)) { + const phi::ArgumentMappingFn* argument_mapping_func = + phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type); + PluginArgumentMappingContext argument_mapping_context(&op_desc_); + phi_kernel_signature = (*argument_mapping_func)(argument_mapping_context); + } else { + phi_kernel_signature = + phi::DefaultKernelSignatureMap::Instance().Get(op_type); + } + + phi::KernelKey phi_kernel_key( + phi::Backend::GPU, phi::DataLayout::ANY, phi::DataType::FLOAT32); + + PADDLE_ENFORCE_EQ( + phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type), + true, + platform::errors::Fatal("%s has no compatible phi kernel!", + op_type.c_str())); + + const phi::Kernel& phi_kernel = phi::KernelFactory::Instance().SelectKernel( + phi_kernel_signature.name, phi_kernel_key); + phi_kernel_ = &phi_kernel; + + PADDLE_ENFORCE_EQ(phi_kernel_->IsValid(), + true, + platform::errors::Fatal("%s phi kernel is invalid!.", + phi_kernel_signature.name)); + + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + platform::CUDAPlace place(platform::GetCurrentDeviceId()); + auto* dev_ctx = static_cast(pool.Get(place)); + + phi_kernel_context_ = new phi::KernelContext(dev_ctx); + dense_tensor_inputs_ = new std::vector(getNbInputs()); + dense_tensor_outputs_ = new std::vector(getNbOutputs()); + + BuildPhiKernelContextAttr( + op_desc_, phi_kernel_context_, phi_kernel_signature, phi_kernel); + return 0; +} + +nvinfer1::DimsExprs GenericPlugin::getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT { + CHECK(output_index < getNbOutputs()); + auto& dynamic_infermeta_factory = tensorrt::DynamicMetaFnFactory::Instance(); + PADDLE_ENFORCE_EQ(dynamic_infermeta_factory.Contains(op_desc_.Type()), + true, + platform::errors::InvalidArgument( + "The %s op has no dynamic plugin infershape function!", + op_desc_.Type().c_str())); + + auto* infershape_func = dynamic_infermeta_factory.Get(op_desc_.Type()); + return infershape_func( + output_index, inputs, nb_inputs, expr_builder, op_desc_); +} + +void GenericPlugin::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* in, + int nb_inputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nb_outputs) TRT_NOEXCEPT { + CHECK(phi_kernel_context_); + CHECK(phi_kernel_); + CHECK(nb_inputs == getNbInputs()); + CHECK(nb_outputs == getNbOutputs()); +} + +// Shutdown the layer. This is called when the engine is destroyed +void GenericPlugin::terminate() TRT_NOEXCEPT { + delete phi_kernel_context_; + delete dense_tensor_inputs_; + delete dense_tensor_outputs_; +} + +int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc, + const nvinfer1::PluginTensorDesc* output_desc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT { + platform::CUDAPlace place(platform::GetCurrentDeviceId()); + + // [TODO]now generic plugin do not support FP16 and INT8 precision + auto protoType2PhiType = [](int proto_type) -> phi::DataType { + if (proto_type == + static_cast(framework::proto::VarType_Type::VarType_Type_FP32)) + return phi::DataType::FLOAT32; + else if (proto_type == + static_cast( + framework::proto::VarType_Type::VarType_Type_INT64) || + proto_type == + static_cast( + framework::proto::VarType_Type::VarType_Type_INT32)) + return phi::DataType::INT32; + else if (proto_type == + static_cast( + framework::proto::VarType_Type::VarType_Type_BOOL)) + return phi::DataType::BOOL; + else + CHECK(false) << "precision is not supported"; + }; + + // input + for (int i = 0; i < getNbInputs(); i++) { + auto const& input_dims = input_desc[i].dims; + + std::vector input_shape; + for (int j = 0; j < input_dims.nbDims; j++) + input_shape.push_back(input_dims.d[j]); + + int input_numel = 1; + for (int k = 0; k < input_shape.size(); k++) input_numel *= input_shape[k]; + + phi::DenseTensorMeta input_meta(protoType2PhiType(inputs_data_type_[i]), + phi::make_ddim(input_shape)); + std::shared_ptr input_alloc( + new phi::Allocation((void*)(inputs[i]), // NOLINT + input_numel * sizeof(int32_t), + place)); + (*dense_tensor_inputs_)[i] = + std::move(phi::DenseTensor(input_alloc, input_meta)); + phi_kernel_context_->EmplaceBackInput(&((*dense_tensor_inputs_)[i])); + } + + // output + for (int i = 0; i < getNbOutputs(); i++) { + auto const& output_dims = output_desc[i].dims; + + std::vector output_shape; + for (int j = 0; j < output_dims.nbDims; j++) + output_shape.push_back(output_dims.d[j]); + + int output_numel = 1; + for (int k = 0; k < output_shape.size(); k++) + output_numel *= output_shape[k]; + + phi::DenseTensorMeta output_meta(protoType2PhiType(outputs_data_type_[i]), + phi::make_ddim(output_shape)); + std::shared_ptr output_alloc( + new phi::Allocation(reinterpret_cast(outputs[i]), + output_numel * sizeof(float), + place)); + phi::DenseTensor output_densetonsor(output_alloc, output_meta); + (*dense_tensor_outputs_)[i] = + std::move(phi::DenseTensor(output_alloc, output_meta)); + phi_kernel_context_->EmplaceBackOutput(&((*dense_tensor_outputs_)[i])); + } + + (*phi_kernel_)(phi_kernel_context_); + + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/generic_plugin.h b/paddle/fluid/inference/tensorrt/plugin/generic_plugin.h new file mode 100644 index 00000000000..39730937af2 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/generic_plugin.h @@ -0,0 +1,162 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/helper.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h" +#include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h" +#include "paddle/fluid/memory/allocation/cuda_allocator.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_context.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc, + phi::KernelContext* kernel_context, + const phi::KernelSignature& signature, + const phi::Kernel& phi_kernel); + +class GenericPlugin : public DynamicPluginTensorRT { + public: + struct InputOutPutVarInfo { + std::vector inputs_data_type; + std::vector outputs_data_type; + }; + + public: + GenericPlugin() {} + + GenericPlugin(const paddle::framework::proto::OpDesc& proto_op_desc, + const InputOutPutVarInfo& in_out_info); + + GenericPlugin(const paddle::framework::proto::OpDesc& proto_op_desc, + const std::vector& inputs_data_type, + const std::vector& outputs_data_type); + + // It was used for tensorrt deserialization. + // It should not be called by users. + GenericPlugin(void const* serialData, size_t serialLength); + + // IPluginV2 method + const char* getPluginType() const TRT_NOEXCEPT override { + return "generic_plugin"; + } + + int getNbOutputs() const TRT_NOEXCEPT override; + + int getNbInputs() const TRT_NOEXCEPT; + + // Initialize the layer for execution. + int initialize() TRT_NOEXCEPT override; + + // Shutdown the layer. This is called when the engine is destroyed + void terminate() TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT{}; + + size_t getSerializationSize() const TRT_NOEXCEPT { + return op_meta_data_.size() + SerializedSize(inputs_data_type_) + + SerializedSize(outputs_data_type_); + } + + void serialize(void* buffer) const TRT_NOEXCEPT; + + // The Func in IPluginV2 + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT; + + nvinfer1::DimsExprs getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) // NOLINT + TRT_NOEXCEPT; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* in_out, + int nb_inputs, + int nb_outputs) TRT_NOEXCEPT; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nb_inputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nb_outputs) TRT_NOEXCEPT; + + int enqueue(const nvinfer1::PluginTensorDesc* input_desc, + const nvinfer1::PluginTensorDesc* output_desc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT; + + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* input_types, + int nb_inputs) const TRT_NOEXCEPT; + + private: + std::string op_meta_data_; + framework::proto::OpDesc proto_op_desc_; + framework::OpDesc op_desc_; + + private: + phi::KernelContext* phi_kernel_context_; + const phi::Kernel* phi_kernel_; + std::vector* dense_tensor_inputs_; + std::vector* dense_tensor_outputs_; + + private: + InputOutPutVarInfo in_out_info_; + std::vector inputs_data_type_; + std::vector outputs_data_type_; +}; + +class GenericPluginCreator : public TensorRTPluginCreator { + public: + const char* getPluginName() const TRT_NOEXCEPT override { + return "generic_plugin"; + } + + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + nvinfer1::IPluginV2DynamicExt* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) + TRT_NOEXCEPT override { + return new GenericPlugin(serial_data, serial_length); + } +}; +REGISTER_TRT_PLUGIN_V2(GenericPluginCreator); + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.h index df404ae3e10..433ff37aac7 100644 --- a/paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.h @@ -125,10 +125,11 @@ class MishPluginDynamic : public DynamicPluginTensorRT { size_t getSerializationSize() const TRT_NOEXCEPT override; void serialize(void* buffer) const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int output_index, - const nvinfer1::DimsExprs* inputs, - int nb_inputs, - nvinfer1::IExprBuilder& expr_builder) + nvinfer1::DimsExprs getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) // NOLINT TRT_NOEXCEPT override; bool supportsFormatCombination(int pos, diff --git a/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc new file mode 100644 index 00000000000..5d9998d2556 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc @@ -0,0 +1,122 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +bool PluginArgumentMappingContext::HasInput(const std::string& name) const { + auto inputs = op_desc_ptr_->Inputs(); + for (auto& i : inputs) { + if (i.first == name && !i.second.empty()) return true; + } + return false; +} + +bool PluginArgumentMappingContext::HasOutput(const std::string& name) const { + auto outputs = op_desc_ptr_->Outputs(); + for (auto& i : outputs) { + if (i.first == name && !i.second.empty()) return true; + } + return false; +} + +bool PluginArgumentMappingContext::HasAttr(const std::string& name) const { + return op_desc_ptr_->HasAttr(name); +} + +paddle::any PluginArgumentMappingContext::Attr( + const std::string& attr_name) const { + auto attr_type = op_desc_ptr_->GetAttrType(attr_name); + switch (attr_type) { + case framework::proto::AttrType::INT: { + return PADDLE_GET_CONST(int, op_desc_ptr_->GetAttr(attr_name)); + break; + }; + case framework::proto::AttrType::FLOAT: { + return PADDLE_GET_CONST(float, op_desc_ptr_->GetAttr(attr_name)); + break; + }; + case framework::proto::AttrType::STRING: { + return PADDLE_GET_CONST(std::string, op_desc_ptr_->GetAttr(attr_name)); + break; + }; + case framework::proto::AttrType::INTS: { + return PADDLE_GET_CONST(std::vector, + op_desc_ptr_->GetAttr(attr_name)); + break; + }; + case framework::proto::AttrType::FLOATS: { + return PADDLE_GET_CONST(std::vector, + op_desc_ptr_->GetAttr(attr_name)); + break; + }; + case framework::proto::AttrType::STRINGS: { + return PADDLE_GET_CONST(std::vector, + op_desc_ptr_->GetAttr(attr_name)); + break; + }; + case framework::proto::AttrType::BOOLEAN: { + return PADDLE_GET_CONST(bool, op_desc_ptr_->GetAttr(attr_name)); + break; + }; + case framework::proto::AttrType::BOOLEANS: { + return PADDLE_GET_CONST(std::vector, + op_desc_ptr_->GetAttr(attr_name)); + break; + }; + default: { + LOG(ERROR) << "Can't conver op's attribute [" << attr_name + << "] to paddle any."; + } + } + return paddle::any(); +} + +size_t PluginArgumentMappingContext::InputSize(const std::string& name) const { + return op_desc_ptr_->Inputs().at(name).size(); +} +size_t PluginArgumentMappingContext::OutputSize(const std::string& name) const { + return op_desc_ptr_->Outputs().at(name).size(); +} +bool PluginArgumentMappingContext::IsDenseTensorInput( + const std::string& name) const { + return false; +} +bool PluginArgumentMappingContext::IsDenseTensorInputs( + const std::string& name) const { + return false; +} +bool PluginArgumentMappingContext::IsSelectedRowsInput( + const std::string& name) const { + return false; +} +bool PluginArgumentMappingContext::IsDenseTensorVectorInput( + const std::string& name) const { + return false; +} + +bool PluginArgumentMappingContext::IsDenseTensorOutput( + const std::string& name) const { + return false; +} +bool PluginArgumentMappingContext::IsSelectedRowsOutput( + const std::string& name) const { + return false; +} +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h new file mode 100644 index 00000000000..35229a5ab79 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h @@ -0,0 +1,62 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/phi/core/compat/arg_map_context.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext { + public: + explicit PluginArgumentMappingContext(framework::OpDesc* op_desc_ptr) + : op_desc_ptr_(op_desc_ptr) {} + + bool HasInput(const std::string& name) const override; + + bool HasOutput(const std::string& name) const override; + + bool HasAttr(const std::string& name) const override; + + paddle::any Attr(const std::string& attr_name) const override; + + size_t InputSize(const std::string& name) const override; + + size_t OutputSize(const std::string& name) const override; + + bool IsDenseTensorInput(const std::string& name) const override; + + bool IsDenseTensorInputs(const std::string& name) const override; + + bool IsSelectedRowsInput(const std::string& name) const override; + + bool IsDenseTensorVectorInput(const std::string& name) const override; + + bool IsDenseTensorOutput(const std::string& name) const override; + + bool IsSelectedRowsOutput(const std::string& name) const override; + + bool IsForInferShape() const override { return false; } + + private: + framework::OpDesc* op_desc_ptr_; +}; +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/test_arg_mapping_context.cc b/paddle/fluid/inference/tensorrt/test_arg_mapping_context.cc new file mode 100644 index 00000000000..75716a91f57 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/test_arg_mapping_context.cc @@ -0,0 +1,132 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(ArgMappingContexTest, BasicFunction) { + paddle::framework::proto::OpDesc op; + op.set_type("imaged_op"); + auto *input_var = op.add_inputs(); + input_var->set_parameter("X"); + *input_var->add_arguments() = "input"; + + auto *output_var = op.add_outputs(); + output_var->set_parameter("Out"); + *output_var->add_arguments() = "output"; + + auto *attr = op.add_attrs(); + attr->set_name("int_attr"); + attr->set_type(paddle::framework::proto::AttrType::INT); + attr->set_i(1); + + attr = op.add_attrs(); + attr->set_name("float_attr"); + attr->set_type(paddle::framework::proto::AttrType::FLOAT); + attr->set_f(1.0); + + attr = op.add_attrs(); + attr->set_name("string_attr"); + attr->set_type(paddle::framework::proto::AttrType::STRING); + attr->set_s("1"); + + attr = op.add_attrs(); + attr->set_name("bool_attr"); + attr->set_type(paddle::framework::proto::AttrType::BOOLEAN); + attr->set_b(true); + + attr = op.add_attrs(); + attr->set_name("ints_attr"); + attr->set_type(paddle::framework::proto::AttrType::INTS); + attr->add_ints(1); + attr->add_ints(2); + + attr = op.add_attrs(); + attr->set_name("floats_attr"); + attr->set_type(paddle::framework::proto::AttrType::FLOATS); + attr->add_floats(1.0); + attr->add_floats(2.0); + + attr = op.add_attrs(); + attr->set_name("strings_attr"); + attr->set_type(paddle::framework::proto::AttrType::STRINGS); + attr->add_strings("1"); + attr->add_strings("2"); + + attr = op.add_attrs(); + attr->set_name("bools_attr"); + attr->set_type(paddle::framework::proto::AttrType::BOOLEANS); + attr->add_bools(true); + attr->add_bools(true); + + framework::OpDesc op_desc(op, nullptr); + PluginArgumentMappingContext context(&op_desc); + + EXPECT_EQ(context.HasInput("X"), true); + EXPECT_EQ(context.HasOutput("Out"), true); + EXPECT_EQ(context.HasAttr("int_attr"), true); + + int int_attr = any_cast(context.Attr("int_attr")); + EXPECT_EQ(int_attr, 1); + + float flaot_attr = any_cast(context.Attr("float_attr")); + EXPECT_EQ(flaot_attr, 1); + + std::string string_attr = any_cast(context.Attr("string_attr")); + EXPECT_EQ(string_attr, "1"); + + bool bool_attr = any_cast(context.Attr("bool_attr")); + EXPECT_EQ(bool_attr, true); + + std::vector ints_attr = + any_cast>(context.Attr("ints_attr")); + EXPECT_EQ(ints_attr[0], 1); + EXPECT_EQ(ints_attr[1], 2); + + std::vector floats_attr = + any_cast>(context.Attr("floats_attr")); + EXPECT_EQ(floats_attr[0], 1.0); + EXPECT_EQ(floats_attr[1], 2.0); + + std::vector strings_attr = + any_cast>(context.Attr("strings_attr")); + EXPECT_EQ(strings_attr[0], "1"); + EXPECT_EQ(strings_attr[1], "2"); + + std::vector bools_attr = + any_cast>(context.Attr("bools_attr")); + EXPECT_EQ(bools_attr[0], true); + EXPECT_EQ(bools_attr[1], true); + + EXPECT_EQ(context.InputSize("X"), true); + EXPECT_EQ(context.OutputSize("Out"), true); + EXPECT_EQ(context.IsDenseTensorInput("X"), false); + EXPECT_EQ(context.IsDenseTensorInputs("X"), false); + EXPECT_EQ(context.IsSelectedRowsInput("X"), false); + EXPECT_EQ(context.IsDenseTensorVectorInput("X"), false); + + EXPECT_EQ(context.IsDenseTensorOutput("Out"), false); + EXPECT_EQ(context.IsSelectedRowsOutput("Out"), false); + EXPECT_EQ(context.IsForInferShape(), false); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc index 7b58a1bb7d6..eab919135c7 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc @@ -159,6 +159,8 @@ void DynamicShapeTest(bool allow_build_at_runtime) { // Execute them. LOG(INFO) << "engine_op run"; + inference::tensorrt::OpTeller::Global().SetOpConverterType( + "fc", inference::tensorrt::OpConverterType::Default); engine_op->Run(scope, place); } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py index 9343f1ebd7c..75b5cba9e81 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py @@ -19,11 +19,15 @@ import paddle.inference as paddle_infer from functools import partial from typing import Optional, List, Callable, Dict, Any, Set import unittest +import os class TrtConvertGatherNdTest_dim_4_1(TrtLayerAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: + # The output has diff between gpu and trt in CI windows + # if ( and self.trt_param.precision == paddle_infer.PrecisionType.Half): + # return False return True def sample_program_configs(self): @@ -46,17 +50,19 @@ class TrtConvertGatherNdTest_dim_4_1(TrtLayerAutoScanTest): "op_attrs": {} }] ops = self.generate_op_config(ops_config) - - program_config = ProgramConfig( - ops=ops, - weights={}, - inputs={ - "input_data": TensorConfig(data_gen=partial(generate_input1)), - "index_data": TensorConfig(data_gen=partial(generate_input2)), - }, - outputs=["output_data"]) - - yield program_config + for i in range(10): + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": + TensorConfig(data_gen=partial(generate_input1)), + "index_data": + TensorConfig(data_gen=partial(generate_input2)), + }, + outputs=["output_data"]) + + yield program_config def sample_predictor_configs( self, program_config) -> (paddle_infer.Config, List[int], float): @@ -71,7 +77,7 @@ class TrtConvertGatherNdTest_dim_4_1(TrtLayerAutoScanTest): "index_data": [1] } self.dynamic_shape.opt_input_shape = { - "input_data": [2, 4, 64, 64], + "input_data": [2, 32, 64, 64], "index_data": [1] } @@ -94,11 +100,23 @@ class TrtConvertGatherNdTest_dim_4_1(TrtLayerAutoScanTest): # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), (0, 4), 1e-5 + yield self.create_inference_config(), (1, 3), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (0, 4), 1e-5 + yield self.create_inference_config(), (1, 3), 1e-5 + + def add_skip_trt_case(self): + + def teller1(program_config, predictor_config): + if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt': + return True + return False + + self.add_skip_case( + teller1, SkipReasons.TRT_NOT_SUPPORT, + "Under Windows Ci, this case will sporadically fail.") def test(self): + self.add_skip_trt_case() self.run_test() @@ -145,14 +163,14 @@ class TrtConvertGatherNdTest_dim_4_1_2(TrtLayerAutoScanTest): def generate_dynamic_shape(attrs): self.dynamic_shape.min_input_shape = { "input_data": [1, 8, 8, 8], - "index_data": [1] + "index_data": [2] } self.dynamic_shape.max_input_shape = { "input_data": [4, 32, 64, 64], - "index_data": [4] + "index_data": [2] } self.dynamic_shape.opt_input_shape = { - "input_data": [2, 4, 64, 64], + "input_data": [2, 32, 64, 64], "index_data": [2] } @@ -175,11 +193,23 @@ class TrtConvertGatherNdTest_dim_4_1_2(TrtLayerAutoScanTest): # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), (0, 4), 1e-5 + yield self.create_inference_config(), (1, 3), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (0, 4), 1e-5 + yield self.create_inference_config(), (1, 3), 1e-5 + + def add_skip_trt_case(self): + + def teller1(program_config, predictor_config): + if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt': + return True + return False + + self.add_skip_case( + teller1, SkipReasons.TRT_NOT_SUPPORT, + "Under Windows Ci, this case will sporadically fail.") def test(self): + self.add_skip_trt_case() self.run_test() @@ -226,14 +256,14 @@ class TrtConvertGatherNdTest_dim_4_2(TrtLayerAutoScanTest): def generate_dynamic_shape(attrs): self.dynamic_shape.min_input_shape = { "input_data": [1, 8, 8, 8], - "index_data": [1, 2] + "index_data": [2, 2] } self.dynamic_shape.max_input_shape = { "input_data": [4, 32, 64, 64], - "index_data": [4, 4] + "index_data": [2, 2] } self.dynamic_shape.opt_input_shape = { - "input_data": [2, 4, 64, 64], + "input_data": [2, 32, 64, 64], "index_data": [2, 2] } @@ -256,11 +286,23 @@ class TrtConvertGatherNdTest_dim_4_2(TrtLayerAutoScanTest): # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), (0, 4), 1e-5 + yield self.create_inference_config(), (1, 3), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (0, 4), 1e-5 + yield self.create_inference_config(), (1, 3), 1e-5 + + def add_skip_trt_case(self): + + def teller1(program_config, predictor_config): + if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt': + return True + return False + + self.add_skip_case( + teller1, SkipReasons.TRT_NOT_SUPPORT, + "Under Windows Ci, this case will sporadically fail.") def test(self): + self.add_skip_trt_case() self.run_test() @@ -307,15 +349,15 @@ class TrtConvertGatherNdTest_dim_4_3(TrtLayerAutoScanTest): def generate_dynamic_shape(attrs): self.dynamic_shape.min_input_shape = { "input_data": [1, 8, 8, 8], - "index_data": [1, 2, 2] + "index_data": [2, 2, 4] } self.dynamic_shape.max_input_shape = { "input_data": [4, 32, 64, 64], - "index_data": [4, 4, 4] + "index_data": [2, 2, 4] } self.dynamic_shape.opt_input_shape = { - "input_data": [2, 4, 64, 64], - "index_data": [2, 2, 2] + "input_data": [2, 32, 64, 64], + "index_data": [2, 2, 4] } def clear_dynamic_shape(): @@ -337,11 +379,23 @@ class TrtConvertGatherNdTest_dim_4_3(TrtLayerAutoScanTest): # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), (0, 4), 1e-5 + yield self.create_inference_config(), (1, 3), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (0, 4), 1e-5 + yield self.create_inference_config(), (1, 3), 1e-5 + + def add_skip_trt_case(self): + + def teller1(program_config, predictor_config): + if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt': + return True + return False + + self.add_skip_case( + teller1, SkipReasons.TRT_NOT_SUPPORT, + "Under Windows Ci, this case will sporadically fail.") def test(self): + self.add_skip_trt_case() self.run_test() @@ -388,11 +442,11 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest): def generate_dynamic_shape(attrs): self.dynamic_shape.min_input_shape = { "input_data": [1, 4], - "index_data": [1, 1] + "index_data": [2, 2] } self.dynamic_shape.max_input_shape = { "input_data": [4, 64], - "index_data": [4, 2] + "index_data": [2, 2] } self.dynamic_shape.opt_input_shape = { "input_data": [2, 8], @@ -418,11 +472,23 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest): # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), (0, 4), 1e-5 + yield self.create_inference_config(), (1, 3), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (0, 4), 1e-5 + yield self.create_inference_config(), (1, 3), 1e-5 + + def add_skip_trt_case(self): + + def teller1(program_config, predictor_config): + if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt': + return True + return False + + self.add_skip_case( + teller1, SkipReasons.TRT_NOT_SUPPORT, + "Under Windows Ci, this case will sporadically fail.") def test(self): + self.add_skip_trt_case() self.run_test() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_yolo_box.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_yolo_box.py index cebede99e6f..fec44769391 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_yolo_box.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_yolo_box.py @@ -107,24 +107,30 @@ class TrtConvertYoloBoxTest(TrtLayerAutoScanTest): if attrs[0]['iou_aware'] == True: channel = 3 * (attrs[0]['class_num'] + 6) self.dynamic_shape.min_input_shape = { - "scale_input": [1, channel, 12, 12] + "yolo_box_input": [1, channel, 12, 12], + "imgsize": [1, 2] } self.dynamic_shape.max_input_shape = { - "scale_input": [4, channel, 24, 24] + "yolo_box_input": [4, channel, 24, 24], + "imgsize": [4, 2] } self.dynamic_shape.opt_input_shape = { - "scale_input": [1, channel, 24, 24] + "yolo_box_input": [1, channel, 24, 24], + "imgsize": [1, 2] } else: channel = 3 * (attrs[0]['class_num'] + 5) self.dynamic_shape.min_input_shape = { - "scale_input": [1, channel, 12, 12] + "yolo_box_input": [1, channel, 12, 12], + "imgsize": [1, 2] } self.dynamic_shape.max_input_shape = { - "scale_input": [4, channel, 24, 24] + "yolo_box_input": [4, channel, 24, 24], + "imgsize": [4, 2] } self.dynamic_shape.opt_input_shape = { - "scale_input": [1, channel, 24, 24] + "yolo_box_input": [1, channel, 24, 24], + "imgsize": [1, 2] } def clear_dynamic_shape(): -- GitLab