未验证 提交 07933116 编写于 作者: W weishengying 提交者: GitHub

General Plugin Mechanism (#45355) (#46070)

上级 2680a71e
......@@ -2185,6 +2185,8 @@ USE_TRT_CONVERTER(shape)
USE_TRT_CONVERTER(fill_constant)
USE_TRT_CONVERTER(fused_token_prune)
USE_TRT_CONVERTER(layernorm_shift_partition)
USE_TRT_CONVERTER(generic_plugin_creater)
USE_TRT_CONVERTER(custom_plugin_creater)
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER(sparse_fc)
USE_TRT_CONVERTER(sparse_multihead_matmul)
......
......@@ -12,10 +12,18 @@ else()
SRCS engine.cc trt_int8_calibrator.cc
DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context)
endif()
nv_library(
tensorrt_dynamic_shape_infermeta_factory
SRCS dynamic_shape_infermeta.cc
DEPS framework_proto)
nv_library(
tensorrt_plugin_arg_mapping_context
SRCS plugin_arg_mapping_context.cc
DEPS framework_proto)
nv_library(
tensorrt_op_teller
SRCS op_teller.cc
DEPS framework_proto device_context)
DEPS framework_proto device_context tensorrt_dynamic_shape_infermeta_factory)
nv_test(
test_tensorrt
SRCS test_tensorrt.cc
......@@ -24,6 +32,10 @@ nv_test(
test_tensorrt_engine
SRCS test_engine.cc test_dynamic_engine.cc
DEPS dynload_cuda tensorrt_engine tensorrt_plugin)
nv_test(
test_arg_mapping_context
SRCS test_arg_mapping_context.cc
DEPS framework_proto tensorrt_plugin_arg_mapping_context)
if(WITH_ONNXRUNTIME AND WIN32)
# Copy onnxruntime for some c++ test in Windows, since the test will
......
......@@ -76,7 +76,8 @@ list(
shape_op.cc
fill_constant_op.cc
fused_token_prune_op.cc
layernorm_shift_partition_op.cc)
layernorm_shift_partition_op.cc
generic_and_custom_plugin_creater.cc)
if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc)
......@@ -85,7 +86,12 @@ endif()
nv_library(
tensorrt_converter
SRCS ${CONVERT_FILES}
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto
DEPS tensorrt_engine
tensorrt_plugin
operator
scope
framework_proto
tensorrt_op_teller
op_registry)
nv_test(
......@@ -94,6 +100,11 @@ nv_test(
DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine
tensorrt_converter)
nv_test(
test_custom_plugin_creater
SRCS test_custom_plugin_creater.cc
DEPS paddle_framework tensorrt_converter op_meta_info custom_operator)
if(WITH_ONNXRUNTIME AND WIN32)
# Copy onnxruntime for some c++ test in Windows, since the test will
# be build only in CI, so suppose the generator in Windows is Ninja.
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/generic_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* Stack converter from fluid to tensorRT.
*/
class CustomPluginCreater : public OpConverter {
public:
void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) override {
framework::OpDesc op_desc(op, nullptr);
VLOG(3) << "convert " << op_desc.Type() << " op to custom pluign layer";
std::string plugin_name;
if (engine_->with_dynamic_shape()) {
plugin_name = op_desc.Type() + "_paddle_trt_dynamic_plugin";
} else {
plugin_name = op_desc.Type() + "_paddle_trt_plugin";
}
nvinfer1::ILayer *layer = nullptr;
std::vector<nvinfer1::ITensor *> inputs;
auto &op_meta_info_map = OpMetaInfoMap::Instance();
const auto &meta_info_map = op_meta_info_map.GetMap();
auto &op_info = meta_info_map.at(op_desc.Type()).front();
// set inputs
auto &op_input_names = framework::OpMetaInfoHelper::GetInputs(op_info);
for (auto &param_name : op_input_names) {
for (auto &arg_name : op_desc.Input(param_name)) {
framework::Variable *X_v = nullptr;
X_v = scope.FindVar(arg_name);
// If this weight is not shared between ops, it need to be convtered to
// itensor
if (X_v && !engine_->GetITensorMap()->count(arg_name)) {
ConvertWeight2ITensor(scope, arg_name);
}
inputs.push_back(engine_->GetITensor(arg_name));
}
}
auto creator =
GetPluginRegistry()->getPluginCreator(plugin_name.c_str(), "1");
CHECK(creator);
// set attrs
std::vector<nvinfer1::PluginField> plugindatas;
auto &op_attrs_names = framework::OpMetaInfoHelper::GetAttrs(op_info);
auto &attrs = op_desc.GetAttrMap();
std::list<int> int_attrs;
std::list<float> float_attrs;
std::list<double> bool_attrs;
std::list<std::string> string_attrs;
std::list<std::vector<int>> ints_attrs;
std::list<std::vector<float>> floats_attrs;
for (auto &attr_name : op_attrs_names) {
nvinfer1::PluginField plugindata;
plugindata.name = attr_name.c_str();
if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::INT) {
int_attrs.push_back(PADDLE_GET_CONST(int, attrs.at(attr_name)));
plugindata.data = &int_attrs.back();
plugindata.type = nvinfer1::PluginFieldType::kINT32;
plugindata.length = 1;
} else if (op_desc.GetAttrType(attr_name) ==
framework::proto::AttrType::FLOAT) {
float_attrs.push_back(PADDLE_GET_CONST(float, attrs.at(attr_name)));
plugindata.data = &float_attrs.back();
plugindata.type = nvinfer1::PluginFieldType::kFLOAT32;
plugindata.length = 1;
} else if (op_desc.GetAttrType(attr_name) ==
framework::proto::AttrType::BOOLEAN) {
int_attrs.push_back(PADDLE_GET_CONST(bool, attrs.at(attr_name)));
plugindata.data = &int_attrs.back();
plugindata.type = nvinfer1::PluginFieldType::kINT32;
plugindata.length = 1;
} else if (op_desc.GetAttrType(attr_name) ==
framework::proto::AttrType::STRING) {
string_attrs.push_back(
PADDLE_GET_CONST(std::string, attrs.at(attr_name)));
plugindata.data = string_attrs.back().data();
plugindata.type = nvinfer1::PluginFieldType::kCHAR;
plugindata.length =
string_attrs.back().size() + 1; // string ends with ‘\0’
} else if (op_desc.GetAttrType(attr_name) ==
framework::proto::AttrType::INTS) {
ints_attrs.push_back(
PADDLE_GET_CONST(std::vector<int>, attrs.at(attr_name)));
plugindata.data = ints_attrs.back().data();
plugindata.type = nvinfer1::PluginFieldType::kINT32;
plugindata.length = ints_attrs.back().size();
} else if (op_desc.GetAttrType(attr_name) ==
framework::proto::AttrType::FLOATS) {
floats_attrs.push_back(
PADDLE_GET_CONST(std::vector<float>, attrs.at(attr_name)));
plugindata.data = floats_attrs.back().data();
plugindata.type = nvinfer1::PluginFieldType::kFLOAT32;
plugindata.length = floats_attrs.back().size();
} else if (op_desc.GetAttrType(attr_name) ==
framework::proto::AttrType::BOOLEANS) {
auto bools_attr =
PADDLE_GET_CONST(std::vector<bool>, attrs.at(attr_name));
std::vector<int> convert_to_ints_attr;
for (bool i : bools_attr) convert_to_ints_attr.push_back(i);
ints_attrs.push_back(convert_to_ints_attr);
plugindata.data = ints_attrs.back().data();
plugindata.type = nvinfer1::PluginFieldType::kINT32;
plugindata.length = ints_attrs.back().size();
} else {
CHECK(false) << "UNKNOWN PluginFieldType.";
}
plugindatas.push_back(plugindata);
}
nvinfer1::PluginFieldCollection plugin_fc{(int32_t)plugindatas.size(),
plugindatas.data()};
auto *plugin = creator->createPlugin(op_desc.Type().c_str(), &plugin_fc);
CHECK(plugin);
if (engine_->with_dynamic_shape()) {
layer =
engine_->AddDynamicPlugin(inputs.data(),
inputs.size(),
(plugin::DynamicPluginTensorRT *)plugin);
} else {
layer = engine_->AddPlugin(
inputs.data(), inputs.size(), (plugin::PluginTensorRT *)plugin);
}
CHECK(layer);
// set outputs
auto &op_output_names = framework::OpMetaInfoHelper::GetOutputs(op_info);
std::vector<std::string> output_names;
for (auto &param_name : op_output_names) {
for (auto &arg_name : op_desc.Output(param_name))
output_names.push_back(arg_name);
}
RreplenishLayerAndOutput(layer, op_desc.Type(), output_names, test_mode);
}
};
class GenericPluginCreater : public OpConverter {
public:
void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) override {
framework::OpDesc op_desc(op, nullptr);
CHECK(block_);
const framework::BlockDesc block_desc(
nullptr, const_cast<framework::proto::BlockDesc *>(block_));
nvinfer1::ILayer *layer = nullptr;
std::vector<nvinfer1::ITensor *> inputs;
phi::KernelSignature phi_kernel_signature;
if (phi::OpUtilsMap::Instance().HasArgumentMappingFn(op_desc.Type())) {
const phi::ArgumentMappingFn *argument_mapping_func =
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_desc.Type());
PluginArgumentMappingContext argument_mapping_context(&op_desc);
phi_kernel_signature = (*argument_mapping_func)(argument_mapping_context);
} else {
phi_kernel_signature =
phi::DefaultKernelSignatureMap::Instance().Get(op_desc.Type());
}
plugin::GenericPlugin::InputOutPutVarInfo in_out_info;
for (auto &param_name : phi_kernel_signature.input_names) {
for (auto &arg_name : op_desc.Input(param_name)) {
framework::Variable *X_v = nullptr;
X_v = scope.FindVar(arg_name);
// If this weight is not shared between ops, it need to be convtered to
// itensor
if (X_v && !engine_->GetITensorMap()->count(arg_name)) {
ConvertWeight2ITensor(scope, arg_name);
}
inputs.push_back(engine_->GetITensor(arg_name));
auto *var = block_desc.FindVar(arg_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"There is no variable called %s in block.", arg_name.c_str()));
PADDLE_ENFORCE_EQ(
var->GetType(),
FluidDT::VarType_Type_LOD_TENSOR,
platform::errors::InvalidArgument("TensorRT engine only takes "
"LoDTensor as input"));
in_out_info.inputs_data_type.push_back(var->GetDataType());
}
}
std::vector<std::string> output_names;
for (auto &param_name : phi_kernel_signature.output_names) {
for (auto &arg_name : op_desc.Output(param_name)) {
output_names.push_back(arg_name);
auto *var = block_desc.FindVar(arg_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"There is no variable called %s in block.", arg_name.c_str()));
PADDLE_ENFORCE_EQ(
var->GetType(),
FluidDT::VarType_Type_LOD_TENSOR,
platform::errors::InvalidArgument("TensorRT engine only takes "
"LoDTensor as input"));
in_out_info.outputs_data_type.push_back(var->GetDataType());
}
}
plugin::GenericPlugin *plugin = new plugin::GenericPlugin(op, in_out_info);
layer = engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin);
RreplenishLayerAndOutput(layer, op_desc.Type(), output_names, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(custom_plugin_creater, CustomPluginCreater);
REGISTER_TRT_OP_CONVERTER(generic_plugin_creater, GenericPluginCreater);
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
......@@ -49,11 +50,15 @@ class OpConverter {
const std::unordered_set<std::string>& parameters,
const framework::Scope& scope,
TensorRTEngine* engine,
bool test_mode = false) {
bool test_mode = false,
const framework::proto::BlockDesc* block = nullptr) {
framework::OpDesc op_desc(op, nullptr);
OpConverter* it{nullptr};
auto op_converter_type_map = OpTeller::Global().GetOpConverterTypeMap();
switch (op_converter_type_map.at(op_desc.Type())) {
case OpConverterType::Default:
if (op_desc.Type() == "mul") {
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(),
1UL,
......@@ -80,28 +85,29 @@ class OpConverter {
"Input(\"Y\").size() = %u.",
op_desc.Input("Y").size()));
int op_type_len = op_desc.Type().size();
std::string op_type = op_desc.Type().substr(op_type_len - 3, op_type_len);
std::string op_type =
op_desc.Type().substr(op_type_len - 3, op_type_len);
std::string Y = op_desc.Input("Y")[0];
if (parameters.count(Y)) {
PADDLE_ENFORCE_GT(
add_weight_op_set.count(op_type),
0,
platform::errors::Unimplemented("Unsupported elementwise type %s",
op_type.c_str()));
it = Registry<OpConverter>::Global().Lookup("elementwise_" + op_type +
"_weight");
platform::errors::Unimplemented(
"Unsupported elementwise type %s", op_type.c_str()));
it = Registry<OpConverter>::Global().Lookup("elementwise_" +
op_type + "_weight");
PADDLE_ENFORCE_NOT_NULL(
it,
platform::errors::Unimplemented("no OpConverter for optype [%s]",
op_desc.Type()));
platform::errors::Unimplemented(
"no OpConverter for optype [%s]", op_desc.Type()));
} else {
PADDLE_ENFORCE_GT(
add_tensor_op_set.count(op_type),
0,
platform::errors::Unimplemented("Unsupported elementwise type %s",
op_type.c_str()));
it = Registry<OpConverter>::Global().Lookup("elementwise_" + op_type +
"_tensor");
platform::errors::Unimplemented(
"Unsupported elementwise type %s", op_type.c_str()));
it = Registry<OpConverter>::Global().Lookup("elementwise_" +
op_type + "_tensor");
}
PADDLE_ENFORCE_NOT_NULL(
it,
......@@ -148,12 +154,31 @@ class OpConverter {
if (!it) {
it = Registry<OpConverter>::Global().Lookup(op_desc.Type());
}
break;
case OpConverterType::GenericPluginCreater:
LOG(INFO) << "There is no OpConverter for type " << op_desc.Type()
<< ", now use generic_plugin_creater!";
it = Registry<OpConverter>::Global().Lookup("generic_plugin_creater");
break;
case OpConverterType::CustomPluginCreater:
LOG(INFO) << "There is no OpConverter for type " << op_desc.Type()
<< ", now use custom_plugin_creater!";
it = Registry<OpConverter>::Global().Lookup("custom_plugin_creater");
break;
default:
CHECK(false) << "no OpConverter for optype " << op_desc.Type();
}
PADDLE_ENFORCE_NOT_NULL(
it,
platform::errors::Unimplemented("no OpConverter for optype [%s]",
op_desc.Type()));
it->SetEngine(engine);
it->SetBlockDesc(block);
(*it)(op, scope, test_mode);
size_t output_num = op_desc.OutputNames().size();
......@@ -257,7 +282,7 @@ class OpConverter {
}
for (int i = 0; i < block.ops_size(); i++) {
const auto& op = block.ops(i);
ConvertOp(op, parameters, scope, engine);
ConvertOp(op, parameters, scope, engine, false, &block);
}
for (int i = 0; i < engine->network()->getNbLayers(); i++) {
auto layer = engine->network()->getLayer(i);
......@@ -620,10 +645,16 @@ class OpConverter {
}
void SetEngine(TensorRTEngine* engine) { engine_ = engine; }
void SetBlockDesc(const framework::proto::BlockDesc* block) {
block_ = block;
}
virtual ~OpConverter() {}
// TensorRT engine
TensorRTEngine* engine_{nullptr};
// BlockDesc
const framework::proto::BlockDesc* block_{nullptr};
protected:
bool test_mode_;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class custom_op_plugin : public nvinfer1::IPluginV2 {
public:
explicit custom_op_plugin(float float_attr) { float_attr_ = float_attr; }
custom_op_plugin(const void* buffer, size_t length) {
DeserializeValue(&buffer, &length, &float_attr_);
}
size_t getSerializationSize() const noexcept override {
return SerializedSize(float_attr_);
}
void serialize(void* buffer) const noexcept override {
SerializeValue(&buffer, float_attr_);
}
nvinfer1::IPluginV2* clone() const noexcept override {
return new custom_op_plugin(float_attr_);
}
~custom_op_plugin() override = default;
const char* getPluginType() const noexcept override {
return "custom_op_paddle_trt_plugin";
}
const char* getPluginVersion() const noexcept override { return "1"; }
int getNbOutputs() const noexcept override { return 1; }
nvinfer1::Dims getOutputDimensions(int index,
const nvinfer1::Dims* inputs,
int nbInputDims) noexcept override {
return inputs[0];
}
bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const noexcept override {
return true;
}
void configureWithFormat(nvinfer1::Dims const* inputDims,
int32_t nbInputs,
nvinfer1::Dims const* outputDims,
int32_t nbOutputs,
nvinfer1::DataType type,
nvinfer1::PluginFormat format,
int32_t maxBatchSize) noexcept override {}
int initialize() noexcept override { return 0; }
void terminate() noexcept override {}
size_t getWorkspaceSize(int maxBatchSize) const noexcept override {
return 0;
}
#if IS_TRT_VERSION_LT(8000)
int enqueue(int batch_size,
const void* const* inputs,
void** outputs,
#else
int enqueue(int batch_size,
const void* const* inputs,
void* const* outputs,
#endif
void* workspace,
cudaStream_t stream) noexcept override {
return 0;
}
void destroy() noexcept override { delete this; }
void setPluginNamespace(const char* libNamespace) noexcept override {
namespace_ = libNamespace;
}
const char* getPluginNamespace() const noexcept override {
return namespace_.c_str();
}
private:
float float_attr_;
std::string namespace_;
};
class custom_op_plugin_creator : public nvinfer1::IPluginCreator {
public:
custom_op_plugin_creator() {}
~custom_op_plugin_creator() override = default;
const char* getPluginName() const noexcept override {
return "custom_op_paddle_trt_plugin";
}
const char* getPluginVersion() const noexcept override { return "1"; }
void setPluginNamespace(const char* pluginNamespace) noexcept override {
plugin_namespace_ = pluginNamespace;
}
const char* getPluginNamespace() const noexcept override {
return plugin_namespace_.c_str();
}
const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override {
return nullptr;
}
nvinfer1::IPluginV2* createPlugin(
const char* name,
const nvinfer1::PluginFieldCollection* fc) noexcept override {
CHECK_EQ(fc->nbFields, 7);
// float_attr
auto attr_field = (fc->fields)[0];
CHECK(attr_field.type == nvinfer1::PluginFieldType::kFLOAT32);
CHECK_EQ(attr_field.length, 1);
float float_value = (reinterpret_cast<const float*>(attr_field.data))[0];
CHECK_EQ(float_value, 1.0);
// int_attr
attr_field = (fc->fields)[1];
CHECK(attr_field.type == nvinfer1::PluginFieldType::kINT32);
CHECK_EQ(attr_field.length, 1);
int int_value = (reinterpret_cast<const int*>(attr_field.data))[0];
CHECK_EQ(int_value, 1);
// bool_attr
attr_field = (fc->fields)[2];
CHECK(attr_field.type == nvinfer1::PluginFieldType::kINT32);
CHECK_EQ(attr_field.length, 1);
int bool_value = (reinterpret_cast<const int*>(attr_field.data))[0];
CHECK_EQ(bool_value, 1);
// string_attr
attr_field = (fc->fields)[3];
CHECK(attr_field.type == nvinfer1::PluginFieldType::kCHAR);
std::string expect_string_attr = "test_string_attr";
CHECK_EQ((size_t)attr_field.length, expect_string_attr.size() + 1);
const char* receive_string_attr =
reinterpret_cast<const char*>(attr_field.data);
CHECK(expect_string_attr == std::string(receive_string_attr));
// ints_attr
attr_field = (fc->fields)[4];
CHECK(attr_field.type == nvinfer1::PluginFieldType::kINT32);
CHECK_EQ(attr_field.length, 3);
const int* ints_value = reinterpret_cast<const int*>(attr_field.data);
CHECK_EQ(ints_value[0], 1);
CHECK_EQ(ints_value[1], 2);
CHECK_EQ(ints_value[2], 3);
// floats_attr
attr_field = (fc->fields)[5];
CHECK(attr_field.type == nvinfer1::PluginFieldType::kFLOAT32);
CHECK_EQ(attr_field.length, 3);
const float* floats_value = reinterpret_cast<const float*>(attr_field.data);
CHECK_EQ(floats_value[0], 1.0);
CHECK_EQ(floats_value[1], 2.0);
CHECK_EQ(floats_value[2], 3.0);
// bools_attr
attr_field = (fc->fields)[6];
CHECK(attr_field.type == nvinfer1::PluginFieldType::kINT32);
CHECK_EQ(attr_field.length, 3);
ints_value = reinterpret_cast<const int*>(attr_field.data);
CHECK_EQ(ints_value[0], true);
CHECK_EQ(ints_value[1], false);
CHECK_EQ(ints_value[2], true);
return new custom_op_plugin(float_value);
}
nvinfer1::IPluginV2* deserializePlugin(
const char* name,
const void* serialData,
size_t serialLength) noexcept override {
return new custom_op_plugin(serialData, serialLength);
}
private:
std::string plugin_namespace_;
};
class custom_op_dynamic_plugin : public nvinfer1::IPluginV2DynamicExt {
public:
explicit custom_op_dynamic_plugin(float float_attr)
: float_attr_(float_attr) {}
custom_op_dynamic_plugin(const void* buffer, size_t length) {
DeserializeValue(&buffer, &length, &float_attr_);
}
~custom_op_dynamic_plugin() override = default;
const char* getPluginType() const noexcept override {
return "custom_op_paddle_trt_dynamic_plugin";
}
const char* getPluginVersion() const noexcept override { return "1"; }
int getNbOutputs() const noexcept override { return 1; }
int initialize() noexcept override { return 0; }
void terminate() noexcept override {}
size_t getSerializationSize() const noexcept override {
return SerializedSize(float_attr_);
}
void serialize(void* buffer) const noexcept override {
SerializeValue(&buffer, float_attr_);
}
void destroy() noexcept override { delete this; }
void setPluginNamespace(const char* libNamespace) noexcept override {
namespace_ = libNamespace;
}
const char* getPluginNamespace() const noexcept override {
return namespace_.c_str();
}
/*IPluginV2Ext method*/
nvinfer1::DataType getOutputDataType(
int32_t index,
nvinfer1::DataType const* inputTypes,
int32_t nbInputs) const noexcept override {
return inputTypes[index];
}
/*IPluginV2DynamicExt method*/
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override {
return new custom_op_dynamic_plugin(float_attr_);
};
nvinfer1::DimsExprs getOutputDimensions(
int32_t outputIndex,
const nvinfer1::DimsExprs* inputs,
int32_t nbInputs,
nvinfer1::IExprBuilder& exprBuilder) noexcept override {
return inputs[0];
}
bool supportsFormatCombination(int32_t pos,
const nvinfer1::PluginTensorDesc* inOut,
int32_t nbInputs,
int32_t nbOutputs) noexcept override {
return true;
}
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int32_t nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int32_t nbOutputs) noexcept override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int32_t nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int32_t nbOutputs) const noexcept override {
return 0;
}
int32_t enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) noexcept override {
return 0;
}
private:
float float_attr_ = 0;
std::string namespace_;
};
class custom_op_dynamic_plugin_creator : public nvinfer1::IPluginCreator {
public:
custom_op_dynamic_plugin_creator() {}
~custom_op_dynamic_plugin_creator() override = default;
const char* getPluginName() const noexcept override {
return "custom_op_paddle_trt_dynamic_plugin";
}
const char* getPluginVersion() const noexcept override { return "1"; }
void setPluginNamespace(char const* pluginNamespace) noexcept override {
plugin_namespace_ = pluginNamespace;
}
const char* getPluginNamespace() const noexcept override {
return plugin_namespace_.c_str();
}
const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override {
return nullptr;
}
nvinfer1::IPluginV2* createPlugin(
const char* name,
const nvinfer1::PluginFieldCollection* fc) noexcept override {
return new custom_op_dynamic_plugin(1.0);
}
nvinfer1::IPluginV2* deserializePlugin(
const char* name,
const void* serialData,
size_t serialLength) noexcept override {
return new custom_op_dynamic_plugin(serialData, serialLength);
}
private:
std::string plugin_namespace_;
};
REGISTER_TRT_PLUGIN_V2(custom_op_plugin_creator);
REGISTER_TRT_PLUGIN_V2(custom_op_dynamic_plugin_creator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h> // NOLINT
#include "paddle/extension.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/test_custom_op_plugin.h"
PD_BUILD_OP(custom_op)
.Inputs({"Input"})
.Outputs({"Output"})
.Attrs({
"float_attr",
"int_attr",
"bool_attr",
"string_attr",
"ints_attr",
"floats_attr",
"bools_attr",
});
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(CustomPluginCreater, StaticShapePlugin) {
framework::ProgramDesc prog;
auto *block = prog.MutableBlock(0);
auto *op = block->AppendOp();
framework::proto::OpDesc *op_desc = op->Proto();
op_desc->set_type("custom_op");
auto *input_var = op_desc->add_inputs();
input_var->set_parameter("Input");
*input_var->add_arguments() = "X";
auto *output_var = op_desc->add_outputs();
output_var->set_parameter("Output");
*output_var->add_arguments() = "Out";
auto *attr = op_desc->add_attrs();
attr->set_name("float_attr");
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(1.0);
attr = op_desc->add_attrs();
attr->set_name("int_attr");
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(1);
attr = op_desc->add_attrs();
attr->set_name("bool_attr");
attr->set_type(paddle::framework::proto::AttrType::BOOLEAN);
attr->set_b(true);
attr = op_desc->add_attrs();
attr->set_name("string_attr");
attr->set_type(paddle::framework::proto::AttrType::STRING);
attr->set_s("test_string_attr");
attr = op_desc->add_attrs();
attr->set_name("ints_attr");
attr->set_type(paddle::framework::proto::AttrType::INTS);
attr->add_ints(1);
attr->add_ints(2);
attr->add_ints(3);
attr = op_desc->add_attrs();
attr->set_name("floats_attr");
attr->set_type(paddle::framework::proto::AttrType::FLOATS);
attr->add_floats(1.0);
attr->add_floats(2.0);
attr->add_floats(3.0);
attr = op_desc->add_attrs();
attr->set_name("bools_attr");
attr->set_type(paddle::framework::proto::AttrType::BOOLEANS);
attr->add_bools(true);
attr->add_bools(false);
attr->add_bools(true);
// init trt engine
std::unique_ptr<TensorRTEngine> engine_;
engine_.reset(new TensorRTEngine(5, 1 << 15));
engine_->InitNetwork();
engine_->DeclareInput(
"X", nvinfer1::DataType::kFLOAT, nvinfer1::Dims3(2, 5, 5));
framework::Scope scope;
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
auto &custom_plugin_tell = OpTeller::Global().GetCustomPluginTeller();
framework::OpDesc custom_op(*op_desc, nullptr);
CHECK_EQ((*custom_plugin_tell)(custom_op, false, false), true);
OpTeller::Global().SetOpConverterType("custom_op",
OpConverterType::CustomPluginCreater);
OpConverter converter;
converter.ConvertBlock(
*block->Proto(), {}, scope, engine_.get() /*TensorRTEngine*/);
}
TEST(CustomPluginCreater, DynamicShapePlugin) {
framework::ProgramDesc prog;
auto *block = prog.MutableBlock(0);
auto *op = block->AppendOp();
framework::proto::OpDesc *op_desc = op->Proto();
op_desc->set_type("custom_op");
auto *input_var = op_desc->add_inputs();
input_var->set_parameter("Input");
*input_var->add_arguments() = "X";
auto *output_var = op_desc->add_outputs();
output_var->set_parameter("Output");
*output_var->add_arguments() = "Out";
auto *attr = op_desc->add_attrs();
attr->set_name("float_attr");
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr = op_desc->add_attrs();
attr->set_name("int_attr");
attr->set_type(paddle::framework::proto::AttrType::INT);
attr = op_desc->add_attrs();
attr->set_name("bool_attr");
attr->set_type(paddle::framework::proto::AttrType::BOOLEAN);
attr = op_desc->add_attrs();
attr->set_name("string_attr");
attr->set_type(paddle::framework::proto::AttrType::STRING);
attr = op_desc->add_attrs();
attr->set_name("ints_attr");
attr->set_type(paddle::framework::proto::AttrType::INTS);
attr = op_desc->add_attrs();
attr->set_name("floats_attr");
attr->set_type(paddle::framework::proto::AttrType::FLOATS);
attr = op_desc->add_attrs();
attr->set_name("bools_attr");
attr->set_type(paddle::framework::proto::AttrType::BOOLEANS);
// init trt engine
std::unique_ptr<TensorRTEngine> engine_;
std::map<std::string, std::vector<int>> min_input_shape = {
{"x", {1, 2, 5, 5}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"x", {1, 2, 5, 5}}};
std::map<std::string, std::vector<int>> optim_input_shape = {
{"x", {1, 2, 5, 5}}};
engine_.reset(new TensorRTEngine(5,
1 << 15,
AnalysisConfig::Precision::kFloat32,
nullptr,
0,
min_input_shape,
max_input_shape,
optim_input_shape));
engine_->InitNetwork();
LOG(INFO) << "with_dynamic_shape " << engine_->with_dynamic_shape();
engine_->DeclareInput(
"X", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4(-1, 2, 5, 5));
framework::Scope scope;
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
auto &custom_plugin_tell = OpTeller::Global().GetCustomPluginTeller();
framework::OpDesc custom_op(*op_desc, nullptr);
CHECK_EQ((*custom_plugin_tell)(custom_op, false, true), true);
OpTeller::Global().SetOpConverterType("custom_op",
OpConverterType::CustomPluginCreater);
OpConverter converter;
converter.ConvertBlock(
*block->Proto(), {}, scope, engine_.get() /*TensorRTEngine*/);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_TRT_CONVERTER(custom_plugin_creater)
......@@ -57,6 +57,7 @@ TEST(OpConverter, ConvertBlock) {
x_tensor->Resize(phi::make_ddim(dim_vec));
x_tensor->mutable_data<float>(platform::CUDAPlace(0));
OpTeller::Global().SetOpConverterType("conv2d", OpConverterType::Default);
OpConverter converter;
converter.ConvertBlock(
*block->Proto(), {"conv2d-Y"}, scope, engine_.get() /*TensorRTEngine*/);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_factory.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
namespace paddle {
namespace inference {
namespace tensorrt {
nvinfer1::DimsExprs GatherNdInferMeta(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder, // NOLINT
const framework::OpDesc& op_desc) {
const nvinfer1::DimsExprs x_dims = inputs[0];
const int x_dims_size = inputs[0].nbDims;
const nvinfer1::DimsExprs index_dims = inputs[1];
const int index_dims_size = inputs[1].nbDims;
std::vector<const nvinfer1::IDimensionExpr*> result_dims;
// The result dims is
// Index.shape[:-1] + X.shape[Index.shape[-1]:]
for (int i = 0; i < index_dims_size - 1; ++i) {
result_dims.emplace_back(index_dims.d[i]);
}
if (index_dims.d[index_dims_size - 1]->isConstant()) {
for (int i = index_dims.d[index_dims_size - 1]->getConstantValue();
i < x_dims_size;
++i) {
result_dims.emplace_back(x_dims.d[i]);
}
}
nvinfer1::DimsExprs output;
output.nbDims = result_dims.size();
for (int i = 0; i < output.nbDims; i++) {
output.d[i] = result_dims[i];
}
return output;
}
PD_REGISTER_DYNAMIC_INFER_META_FN(gather_nd, GatherNdInferMeta);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <NvInfer.h>
#include <string>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/macros.h"
#include "paddle/utils/flat_hash_map.h"
namespace paddle {
namespace inference {
namespace tensorrt {
using DynamicMetaFn =
nvinfer1::DimsExprs (*)(int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder, // NOLINT
const framework::OpDesc& op_desc);
class DynamicMetaFnFactory {
public:
static DynamicMetaFnFactory& Instance() {
static DynamicMetaFnFactory g_meta_fn_map;
return g_meta_fn_map;
}
bool Contains(const std::string& op_name) const {
return meta_fn_map_.count(op_name) > 0;
}
void Insert(std::string op_name, DynamicMetaFn infer_meta_fn) {
PADDLE_ENFORCE_NE(
Contains(op_name),
true,
phi::errors::AlreadyExists(
"`%s` op's DynamicInferMetaFn has been registered.", op_name));
meta_fn_map_.insert({std::move(op_name), std::move(infer_meta_fn)});
}
const DynamicMetaFn& Get(const std::string& op_name) const {
auto it = meta_fn_map_.find(op_name);
PADDLE_ENFORCE_NE(
it,
meta_fn_map_.end(),
phi::errors::NotFound(
"`%s` op's DynamicInferMetaFn has been registered.", op_name));
return it->second;
}
private:
DynamicMetaFnFactory() = default;
paddle::flat_hash_map<std::string, DynamicMetaFn> meta_fn_map_;
DISABLE_COPY_AND_ASSIGN(DynamicMetaFnFactory);
};
struct DynamicMetaFnRegistrar {
DynamicMetaFnRegistrar(const char* op_name, DynamicMetaFn infer_meta_fn) {
DynamicMetaFnFactory::Instance().Insert(op_name, std::move(infer_meta_fn));
}
static void Touch() {}
};
#define PD_REGISTER_DYNAMIC_INFER_META_FN(op_name, dynamic_infer_meta_fn) \
static paddle::inference::tensorrt::DynamicMetaFnRegistrar \
registrar_dynamic_infer_meta_fn_for_##op_name(#op_name, \
dynamic_infer_meta_fn); \
int TouchDynamicMetaFnRegistrar_##op_name() { \
registrar_dynamic_infer_meta_fn_for_##op_name.Touch(); \
return 0; \
}
#define USE_TRT_DYNAMIC_INFER_META_FN(op_name) \
extern int TouchDynamicMetaFnRegistrar_##op_name(); \
static int use_op_dynamic_infer_meta##op_name UNUSED = \
TouchDynamicMetaFnRegistrar_##op_name();
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
USE_TRT_DYNAMIC_INFER_META_FN(gather_nd);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -18,6 +18,11 @@
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_factory.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_factory.h"
namespace paddle {
namespace framework {
......@@ -60,252 +65,16 @@ struct SimpleOpTypeSetTeller : public Teller {
#endif
}
bool operator()(const std::string& op_type,
const framework::OpDesc& desc,
bool use_no_calib_int8) override {
if (use_no_calib_int8) {
return int8_teller_set.count(op_type);
} else {
return teller_set.count(op_type);
}
}
private:
// use this set for no calib int8.
std::unordered_set<std::string> int8_teller_set{
"mul",
"matmul",
"conv2d",
"conv2d_fusion",
"pool2d",
"relu",
"elu",
"selu",
"softsign",
"softplus",
"stanh",
"thresholded_relu",
"exp",
"log",
"sqrt",
"abs",
"sin",
"cos",
"tan",
"sinh",
"cosh",
"asin",
"acos",
"atan",
"asinh",
"atanh",
"ceil",
"floor",
"erf",
"softmax",
"sigmoid",
"hard_swish",
"depthwise_conv2d",
"batch_norm",
"concat",
"tanh",
"pad",
"elementwise_add",
"elementwise_sub",
"elementwise_mul",
"elementwise_div",
"elementwise_pow",
"equal",
"dropout",
"prelu",
"conv2d_transpose",
"depthwise_conv2d_transpose",
"leaky_relu",
"fc",
"shuffle_channel",
"swish",
"silu",
"split",
"instance_norm",
"gelu",
"layer_norm",
"scale",
"stack",
"transpose2",
"transpose",
"top_k",
"top_k_v2",
"flatten2",
"flatten",
"gather",
"gather_nd",
"yolo_box",
"yolo_box_head",
"arg_max",
"roi_align",
"affine_channel",
"nearest_interp",
"anchor_generator",
"reduce_sum",
"reduce_mean",
"conv3d",
"conv3d_transpose",
"mish",
"nearest_interp_v2",
"bilinear_interp_v2",
"pool3d",
"deformable_conv",
"relu6",
"hard_sigmoid",
"clip",
"fused_embedding_eltwise_layernorm",
"multihead_matmul",
"skip_layernorm",
"slice",
"strided_slice",
"fused_preln_embedding_eltwise_layernorm",
"preln_residual_bias",
"c_allreduce_sum",
"c_allreduce_min",
"c_allreduce_max",
"c_allreduce_prod",
"roll",
"cast",
"preln_skip_layernorm",
"transformer_input_convert",
"recover_padding",
"remove_padding",
"fill_constant",
"sum",
"shape",
"squeeze2",
"unsqueeze2",
"layernorm_shift_partition"};
std::unordered_set<std::string> teller_set{
"mul",
"matmul",
"conv2d",
"conv2d_fusion",
"pool2d",
"relu",
"elu",
"selu",
"softsign",
"softplus",
"stanh",
"thresholded_relu",
"exp",
"log",
"sqrt",
"abs",
"sin",
"cos",
"tan",
"sinh",
"cosh",
"asin",
"acos",
"atan",
"asinh",
"atanh",
"ceil",
"floor",
"erf",
"softmax",
"sigmoid",
"hard_swish",
"depthwise_conv2d",
"batch_norm",
"concat",
"tanh",
"pad",
"elementwise_add",
"elementwise_sub",
"elementwise_mul",
"elementwise_div",
"elementwise_pow",
"equal",
"dropout",
"prelu",
"conv2d_transpose",
"depthwise_conv2d_transpose",
"leaky_relu",
"fc",
"shuffle_channel",
"swish",
"silu",
"split",
"instance_norm",
"gelu",
"layer_norm",
"scale",
"stack",
"transpose2",
"transpose",
"top_k",
"top_k_v2",
"flatten2",
"flatten",
"gather",
"gather_nd",
"yolo_box",
"yolo_box_head",
"arg_max",
"roi_align",
"affine_channel",
"nearest_interp",
"anchor_generator",
"reduce_sum",
"reduce_mean",
"conv3d",
"conv3d_transpose",
"mish",
"bilinear_interp_v2",
"nearest_interp_v2",
"pool3d",
"deformable_conv",
"relu6",
"hard_sigmoid",
"clip",
"fused_embedding_eltwise_layernorm",
"multihead_matmul",
"skip_layernorm",
"slice",
"strided_slice",
"fused_preln_embedding_eltwise_layernorm",
"preln_skip_layernorm",
"preln_residual_bias",
"c_allreduce_sum",
"c_allreduce_min",
"c_allreduce_max",
"c_allreduce_prod",
"roll",
"cast",
"transformer_input_convert",
"recover_padding",
"remove_padding",
"fill_constant",
"sum",
"shape",
"squeeze2",
"unsqueeze2",
"fused_token_prune",
"layernorm_shift_partition"};
};
bool OpTeller::Tell(const framework::ir::Node* node,
bool use_no_calib_int8,
bool with_dynamic_shape) {
const std::string op_type = node->Op()->Type();
const framework::OpDesc desc = *node->Op();
bool operator()(const framework::OpDesc& desc,
bool use_no_calib_int8 = false,
bool with_dynamic_shape = false) override {
const std::string op_type = desc.Type();
// do not support the op which is labeled the `skip_quant`
if ((desc.HasAttr("namescope") &&
PADDLE_GET_CONST(std::string, desc.GetAttr("op_namescope")) ==
"/skip_quant_2/") ||
desc.HasAttr("skip_quant"))
return false;
for (auto& teller : tellers_) {
std::unordered_set<std::string> act_op_list = {
"relu", "relu6", "sigmoid",
"elu", "selu", "softsign",
......@@ -2300,13 +2069,329 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
}
if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
if (use_no_calib_int8) {
return int8_teller_set.count(op_type);
} else {
return teller_set.count(op_type);
}
}
private:
// use this set for no calib int8.
std::unordered_set<std::string> int8_teller_set{
"mul",
"matmul",
"conv2d",
"conv2d_fusion",
"pool2d",
"relu",
"elu",
"selu",
"softsign",
"softplus",
"stanh",
"thresholded_relu",
"exp",
"log",
"sqrt",
"abs",
"sin",
"cos",
"tan",
"sinh",
"cosh",
"asin",
"acos",
"atan",
"asinh",
"atanh",
"ceil",
"floor",
"erf",
"softmax",
"sigmoid",
"hard_swish",
"depthwise_conv2d",
"batch_norm",
"concat",
"tanh",
"pad",
"elementwise_add",
"elementwise_sub",
"elementwise_mul",
"elementwise_div",
"elementwise_pow",
"equal",
"dropout",
"prelu",
"conv2d_transpose",
"depthwise_conv2d_transpose",
"leaky_relu",
"fc",
"shuffle_channel",
"swish",
"silu",
"split",
"instance_norm",
"gelu",
"layer_norm",
"scale",
"stack",
"transpose2",
"transpose",
"top_k",
"top_k_v2",
"flatten2",
"flatten",
"gather",
"gather_nd",
"yolo_box",
"yolo_box_head",
"arg_max",
"roi_align",
"affine_channel",
"nearest_interp",
"anchor_generator",
"reduce_sum",
"reduce_mean",
"conv3d",
"conv3d_transpose",
"mish",
"nearest_interp_v2",
"bilinear_interp_v2",
"pool3d",
"deformable_conv",
"relu6",
"hard_sigmoid",
"clip",
"fused_embedding_eltwise_layernorm",
"multihead_matmul",
"skip_layernorm",
"slice",
"strided_slice",
"fused_preln_embedding_eltwise_layernorm",
"preln_residual_bias",
"c_allreduce_sum",
"c_allreduce_min",
"c_allreduce_max",
"c_allreduce_prod",
"roll",
"cast",
"preln_skip_layernorm",
"transformer_input_convert",
"recover_padding",
"remove_padding",
"fill_constant",
"sum",
"shape",
"squeeze2",
"unsqueeze2",
"layernorm_shift_partition"};
std::unordered_set<std::string> teller_set{
"mul",
"matmul",
"conv2d",
"conv2d_fusion",
"pool2d",
"relu",
"elu",
"selu",
"softsign",
"softplus",
"stanh",
"thresholded_relu",
"exp",
"log",
"sqrt",
"abs",
"sin",
"cos",
"tan",
"sinh",
"cosh",
"asin",
"acos",
"atan",
"asinh",
"atanh",
"ceil",
"floor",
"erf",
"softmax",
"sigmoid",
"hard_swish",
"depthwise_conv2d",
"batch_norm",
"concat",
"tanh",
"pad",
"elementwise_add",
"elementwise_sub",
"elementwise_mul",
"elementwise_div",
"elementwise_pow",
"equal",
"dropout",
"prelu",
"conv2d_transpose",
"depthwise_conv2d_transpose",
"leaky_relu",
"fc",
"shuffle_channel",
"swish",
"silu",
"split",
"instance_norm",
"gelu",
"layer_norm",
"scale",
"stack",
"transpose2",
"transpose",
"top_k",
"top_k_v2",
"flatten2",
"flatten",
"gather",
"gather_nd",
"yolo_box",
"yolo_box_head",
"arg_max",
"roi_align",
"affine_channel",
"nearest_interp",
"anchor_generator",
"reduce_sum",
"reduce_mean",
"conv3d",
"conv3d_transpose",
"mish",
"bilinear_interp_v2",
"nearest_interp_v2",
"pool3d",
"deformable_conv",
"relu6",
"hard_sigmoid",
"clip",
"fused_embedding_eltwise_layernorm",
"multihead_matmul",
"skip_layernorm",
"slice",
"strided_slice",
"fused_preln_embedding_eltwise_layernorm",
"preln_skip_layernorm",
"preln_residual_bias",
"c_allreduce_sum",
"c_allreduce_min",
"c_allreduce_max",
"c_allreduce_prod",
"roll",
"cast",
"transformer_input_convert",
"recover_padding",
"remove_padding",
"fill_constant",
"sum",
"shape",
"squeeze2",
"unsqueeze2",
"fused_token_prune",
"layernorm_shift_partition"};
};
struct GenericPluginTeller : public Teller {
public:
GenericPluginTeller() {}
bool operator()(const framework::OpDesc& desc,
bool use_no_calib_int8 = false,
bool with_dynamic_shape = false) override {
const std::string op_type = desc.Type();
// only consider dynamic_shape mode
if (!with_dynamic_shape) {
return false;
}
if (use_no_calib_int8) {
return false;
} else {
framework::InitDefaultKernelSignatureMap();
bool res = phi::OpUtilsMap::Instance().HasArgumentMappingFn(op_type) ||
phi::DefaultKernelSignatureMap::Instance().Has(op_type);
if (!res) {
VLOG(3) << op_type << " has no KernelSignature";
return false;
}
res = phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type);
if (!res) {
VLOG(3) << op_type << " has no CompatiblePhiKernel in phi.";
return false;
}
auto& dynamic_infermeta_factory =
tensorrt::DynamicMetaFnFactory::Instance();
res = dynamic_infermeta_factory.Contains(op_type);
if (!res) {
VLOG(3) << op_type << " has no DynamicMetaFn.";
return false;
}
return true;
}
}
};
struct CustomPluginTeller : public Teller {
public:
CustomPluginTeller() {}
bool operator()(const framework::OpDesc& desc,
bool use_no_calib_int8 = false,
bool with_dynamic_shape = false) override {
const std::string op_type = desc.Type();
std::string expect_plugin_name;
if (with_dynamic_shape) {
expect_plugin_name = op_type + "_paddle_trt_dynamic_plugin";
} else {
expect_plugin_name = op_type + "_paddle_trt_plugin";
}
int num = 0;
auto creators = GetPluginRegistry()->getPluginCreatorList(&num);
for (int i = 0; i < num; i++) {
if (std::string(creators[i]->getPluginName()) == expect_plugin_name)
return true;
}
return false;
}
};
bool OpTeller::Tell(const framework::ir::Node* node,
bool use_no_calib_int8,
bool with_dynamic_shape) {
const std::string op_type = node->Op()->Type();
const framework::OpDesc desc = *node->Op();
auto& default_teller = GetDefaultTeller();
if ((*default_teller)(desc, use_no_calib_int8, with_dynamic_shape)) {
SetOpConverterType(op_type, OpConverterType::Default);
return true;
}
auto& generic_plugin_teller = GetGenericPluginTeller();
if ((*generic_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) {
SetOpConverterType(op_type, OpConverterType::GenericPluginCreater);
return true;
}
auto& custom_plugin_teller = GetCustomPluginTeller();
if ((*custom_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape)) {
SetOpConverterType(op_type, OpConverterType::CustomPluginCreater);
return true;
}
return false;
}
OpTeller::OpTeller() { tellers_.emplace_back(new SimpleOpTypeSetTeller); }
OpTeller::OpTeller() {
tellers_.emplace_back(new tensorrt::SimpleOpTypeSetTeller);
tellers_.emplace_back(new tensorrt::GenericPluginTeller);
tellers_.emplace_back(new tensorrt::CustomPluginTeller);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -38,9 +38,9 @@ namespace tensorrt {
* issues such as op_desc.
*/
struct Teller {
virtual bool operator()(const std::string& op_type,
const framework::OpDesc& desc,
bool use_no_calib_int8) = 0;
virtual bool operator()(const framework::OpDesc& desc,
bool use_no_calib_int8 = false,
bool with_dynamic_shape = false) = 0;
virtual ~Teller() = default;
};
......@@ -55,9 +55,15 @@ struct Teller {
*};
*/
enum class OpConverterType {
Default = 0,
GenericPluginCreater,
CustomPluginCreater
};
/*
* class OpTeller helps to tell whether a fluid
* operator can be transformed to a TensorRT layer.
* operator can be transformed to a TensorRT layer
* and use which kind of OpConverter
*/
class OpTeller {
public:
......@@ -70,11 +76,26 @@ class OpTeller {
bool use_no_calib_int8 = false,
bool with_dynamic_shape = false);
std::unique_ptr<Teller>& GetDefaultTeller() { return tellers_.at(0); }
std::unique_ptr<Teller>& GetGenericPluginTeller() { return tellers_.at(1); }
std::unique_ptr<Teller>& GetCustomPluginTeller() { return tellers_.at(2); }
void SetOpConverterType(std::string name, OpConverterType type) {
op_converter_type_map_[name] = type;
}
const std::map<std::string, OpConverterType>& GetOpConverterTypeMap() const {
return op_converter_type_map_;
}
private:
OpTeller();
private:
std::vector<std::unique_ptr<Teller>> tellers_;
std::map<std::string, OpConverterType> op_converter_type_map_;
};
} // namespace tensorrt
......
......@@ -32,7 +32,8 @@ list(
c_allreduce_op_plugin.cu
preln_residual_bias_plugin.cu
fused_token_prune_op_plugin.cu
layernorm_shift_partition_op.cu)
layernorm_shift_partition_op.cu
generic_plugin.cu)
if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
list(APPEND TRT_FILES spmm_plugin.cu)
......@@ -41,7 +42,13 @@ endif()
nv_library(
tensorrt_plugin
SRCS ${TRT_FILES}
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
DEPS enforce
tensorrt_engine
prelu
tensor
bert_encoder_functor
tensorrt_dynamic_shape_infermeta_factory
tensorrt_plugin_arg_mapping_context)
nv_test(
test_split_plugin
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/generic_plugin.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_registry.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc,
phi::KernelContext* kernel_context,
const phi::KernelSignature& signature,
const phi::Kernel& phi_kernel) {
const phi::KernelArgsDef& args_def = phi_kernel.args_def();
const auto& attr_names = signature.attr_names;
const auto& attr_defs = args_def.attribute_defs();
PADDLE_ENFORCE_EQ(
attr_names.size(),
attr_defs.size(),
platform::errors::InvalidArgument(
"The attr_names.size() should be equal to attr_defs.size()."));
framework::AttrReader attr_reader(op_desc.GetAttrMap());
for (size_t k = 0; k < attr_names.size(); ++k) {
auto attr_name = attr_names[k];
auto* attr_ptr = attr_reader.GetAttr(attr_name);
if (attr_ptr) {
switch (attr_defs[k].type_index) {
case phi::AttributeType::SCALAR: {
auto& attr = *attr_ptr;
switch (AttrTypeID(attr)) {
case framework::proto::AttrType::FLOAT:
return kernel_context->EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(float, attr)));
break;
case framework::proto::AttrType::INT:
return kernel_context->EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(int, attr)));
break;
case framework::proto::AttrType::STRING:
return kernel_context->EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(std::string, attr)));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when "
"ProtoAttr2PhiAttr.",
attr_name));
}
} break;
case phi::AttributeType::INT_ARRAY: {
auto& attr = *attr_ptr;
switch (AttrTypeID(attr)) {
case framework::proto::AttrType::INTS:
kernel_context->EmplaceBackAttr(std::move(
phi::IntArray(PADDLE_GET_CONST(std::vector<int32_t>, attr))));
break;
case framework::proto::AttrType::LONGS:
kernel_context->EmplaceBackAttr(std::move(
phi::IntArray(PADDLE_GET_CONST(std::vector<int64_t>, attr))));
break;
case framework::proto::AttrType::INT:
kernel_context->EmplaceBackAttr(
phi::IntArray({PADDLE_GET_CONST(int, attr)}));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to IntArray when "
"ProtoAttr2PhiAttr.",
attr_name));
}
} break;
case phi::AttributeType::SCALARS: {
auto& attr = *attr_ptr;
switch (AttrTypeID(attr)) {
case framework::proto::AttrType::INTS: {
const auto& vec = PADDLE_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_context->EmplaceBackAttr(std::move(scalar_list));
} break;
case framework::proto::AttrType::LONGS: {
const auto& vec = PADDLE_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_context->EmplaceBackAttr(std::move(scalar_list));
} break;
case framework::proto::AttrType::FLOATS: {
const auto& vec = PADDLE_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_context->EmplaceBackAttr(std::move(scalar_list));
} break;
case framework::proto::AttrType::FLOAT64S: {
const auto& vec = PADDLE_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_context->EmplaceBackAttr(std::move(scalar_list));
} break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"ProtoAttr2PhiAttr.",
attr_name));
}
} break;
default: {
auto& attr = *attr_ptr;
switch (attr_defs[k].type_index) {
case phi::AttributeType::FLOAT32:
kernel_context->EmplaceBackAttr(PADDLE_GET_CONST(float, attr));
break;
case phi::AttributeType::INT32:
kernel_context->EmplaceBackAttr(PADDLE_GET_CONST(int, attr));
break;
case phi::AttributeType::BOOL:
kernel_context->EmplaceBackAttr(PADDLE_GET_CONST(bool, attr));
break;
case phi::AttributeType::INT64:
kernel_context->EmplaceBackAttr(PADDLE_GET_CONST(int64_t, attr));
break;
case phi::AttributeType::INT32S:
kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(std::vector<int>, attr));
break;
case phi::AttributeType::DATA_TYPE: {
auto data_type = paddle::framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(
PADDLE_GET_CONST(int, attr)));
kernel_context->EmplaceBackAttr(data_type);
} break;
case phi::AttributeType::STRING:
kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(std::string, attr));
break;
case phi::AttributeType::INT64S:
switch (AttrTypeID(attr)) {
case framework::proto::AttrType::LONGS:
kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(std::vector<int64_t>, attr));
break;
case framework::proto::AttrType::INTS: {
const auto& vector_int_attr =
PADDLE_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(
vector_int_attr.begin(), vector_int_attr.end());
kernel_context->EmplaceBackAttr(vector_int64_attr);
} break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<int64_t> "
"when ProtoAttr2PhiAttr.",
attr_name));
}
break;
case phi::AttributeType::FLOAT32S:
kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(std::vector<float>, attr));
break;
case phi::AttributeType::STRINGS:
kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(std::vector<std::string>, attr));
break;
case phi::AttributeType::BOOLS:
kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(std::vector<bool>, attr));
break;
case phi::AttributeType::FLOAT64S:
kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(std::vector<double>, attr));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
"ProtoAttr2PhiAttr.",
attr_name));
}
}
}
}
}
}
GenericPlugin::GenericPlugin(
const paddle::framework::proto::OpDesc& proto_op_desc,
const InputOutPutVarInfo& in_out_info) {
proto_op_desc_ = proto_op_desc;
op_desc_ = std::move(framework::OpDesc(proto_op_desc_, nullptr));
proto_op_desc_.SerializeToString(&op_meta_data_);
inputs_data_type_ = in_out_info.inputs_data_type;
outputs_data_type_ = in_out_info.outputs_data_type;
}
GenericPlugin::GenericPlugin(
const paddle::framework::proto::OpDesc& proto_op_desc,
const std::vector<int>& inputs_data_type,
const std::vector<int>& outputs_data_type) {
proto_op_desc_ = proto_op_desc;
op_desc_ = std::move(framework::OpDesc(proto_op_desc_, nullptr));
proto_op_desc_.SerializeToString(&op_meta_data_);
inputs_data_type_ = inputs_data_type;
outputs_data_type_ = outputs_data_type;
}
GenericPlugin::GenericPlugin(void const* serial_data, size_t serial_length) {
DeserializeValue(&serial_data, &serial_length, &inputs_data_type_);
DeserializeValue(&serial_data, &serial_length, &outputs_data_type_);
std::string op_meta_data((char*)(serial_data), serial_length); // NOLINT
op_meta_data_ = std::move(op_meta_data);
proto_op_desc_.ParseFromString(op_meta_data_);
op_desc_ = std::move(framework::OpDesc(proto_op_desc_, nullptr));
}
int GenericPlugin::getNbOutputs() const TRT_NOEXCEPT {
int res = 0;
for (auto& i : op_desc_.Outputs()) {
if (!i.second.empty()) res += i.second.size();
}
return res;
}
int GenericPlugin::getNbInputs() const TRT_NOEXCEPT {
int res = 0;
for (auto& i : op_desc_.Inputs()) {
if (!i.second.empty()) res += i.second.size();
}
return res;
}
nvinfer1::IPluginV2DynamicExt* GenericPlugin::clone() const TRT_NOEXCEPT {
nvinfer1::IPluginV2DynamicExt* plugin =
new GenericPlugin(proto_op_desc_, inputs_data_type_, outputs_data_type_);
plugin->initialize();
return plugin;
}
void GenericPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
// inputs_data_type_
SerializeValue(&buffer, inputs_data_type_);
// outputs_data_type_
SerializeValue(&buffer, outputs_data_type_);
// serialize op_meta_data_
std::memcpy(buffer, op_meta_data_.c_str(), op_meta_data_.size());
reinterpret_cast<char*&>(buffer) += op_meta_data_.size();
}
bool GenericPlugin::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc* in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
return true;
}
nvinfer1::DataType GenericPlugin::getOutputDataType(
int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT {
return input_types[0];
}
int GenericPlugin::initialize() TRT_NOEXCEPT {
std::string op_type = op_desc_.Type();
phi::KernelSignature phi_kernel_signature;
if (phi::OpUtilsMap::Instance().HasArgumentMappingFn(op_type)) {
const phi::ArgumentMappingFn* argument_mapping_func =
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
PluginArgumentMappingContext argument_mapping_context(&op_desc_);
phi_kernel_signature = (*argument_mapping_func)(argument_mapping_context);
} else {
phi_kernel_signature =
phi::DefaultKernelSignatureMap::Instance().Get(op_type);
}
phi::KernelKey phi_kernel_key(
phi::Backend::GPU, phi::DataLayout::ANY, phi::DataType::FLOAT32);
PADDLE_ENFORCE_EQ(
phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type),
true,
platform::errors::Fatal("%s has no compatible phi kernel!",
op_type.c_str()));
const phi::Kernel& phi_kernel = phi::KernelFactory::Instance().SelectKernel(
phi_kernel_signature.name, phi_kernel_key);
phi_kernel_ = &phi_kernel;
PADDLE_ENFORCE_EQ(phi_kernel_->IsValid(),
true,
platform::errors::Fatal("%s phi kernel is invalid!.",
phi_kernel_signature.name));
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
platform::CUDAPlace place(platform::GetCurrentDeviceId());
auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(place));
phi_kernel_context_ = new phi::KernelContext(dev_ctx);
dense_tensor_inputs_ = new std::vector<phi::DenseTensor>(getNbInputs());
dense_tensor_outputs_ = new std::vector<phi::DenseTensor>(getNbOutputs());
BuildPhiKernelContextAttr(
op_desc_, phi_kernel_context_, phi_kernel_signature, phi_kernel);
return 0;
}
nvinfer1::DimsExprs GenericPlugin::getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
CHECK(output_index < getNbOutputs());
auto& dynamic_infermeta_factory = tensorrt::DynamicMetaFnFactory::Instance();
PADDLE_ENFORCE_EQ(dynamic_infermeta_factory.Contains(op_desc_.Type()),
true,
platform::errors::InvalidArgument(
"The %s op has no dynamic plugin infershape function!",
op_desc_.Type().c_str()));
auto* infershape_func = dynamic_infermeta_factory.Get(op_desc_.Type());
return infershape_func(
output_index, inputs, nb_inputs, expr_builder, op_desc_);
}
void GenericPlugin::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* in,
int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nb_outputs) TRT_NOEXCEPT {
CHECK(phi_kernel_context_);
CHECK(phi_kernel_);
CHECK(nb_inputs == getNbInputs());
CHECK(nb_outputs == getNbOutputs());
}
// Shutdown the layer. This is called when the engine is destroyed
void GenericPlugin::terminate() TRT_NOEXCEPT {
delete phi_kernel_context_;
delete dense_tensor_inputs_;
delete dense_tensor_outputs_;
}
int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT {
platform::CUDAPlace place(platform::GetCurrentDeviceId());
// [TODO]now generic plugin do not support FP16 and INT8 precision
auto protoType2PhiType = [](int proto_type) -> phi::DataType {
if (proto_type ==
static_cast<int>(framework::proto::VarType_Type::VarType_Type_FP32))
return phi::DataType::FLOAT32;
else if (proto_type ==
static_cast<int>(
framework::proto::VarType_Type::VarType_Type_INT64) ||
proto_type ==
static_cast<int>(
framework::proto::VarType_Type::VarType_Type_INT32))
return phi::DataType::INT32;
else if (proto_type ==
static_cast<int>(
framework::proto::VarType_Type::VarType_Type_BOOL))
return phi::DataType::BOOL;
else
CHECK(false) << "precision is not supported";
};
// input
for (int i = 0; i < getNbInputs(); i++) {
auto const& input_dims = input_desc[i].dims;
std::vector<int> input_shape;
for (int j = 0; j < input_dims.nbDims; j++)
input_shape.push_back(input_dims.d[j]);
int input_numel = 1;
for (int k = 0; k < input_shape.size(); k++) input_numel *= input_shape[k];
phi::DenseTensorMeta input_meta(protoType2PhiType(inputs_data_type_[i]),
phi::make_ddim(input_shape));
std::shared_ptr<phi::Allocation> input_alloc(
new phi::Allocation((void*)(inputs[i]), // NOLINT
input_numel * sizeof(int32_t),
place));
(*dense_tensor_inputs_)[i] =
std::move(phi::DenseTensor(input_alloc, input_meta));
phi_kernel_context_->EmplaceBackInput(&((*dense_tensor_inputs_)[i]));
}
// output
for (int i = 0; i < getNbOutputs(); i++) {
auto const& output_dims = output_desc[i].dims;
std::vector<int> output_shape;
for (int j = 0; j < output_dims.nbDims; j++)
output_shape.push_back(output_dims.d[j]);
int output_numel = 1;
for (int k = 0; k < output_shape.size(); k++)
output_numel *= output_shape[k];
phi::DenseTensorMeta output_meta(protoType2PhiType(outputs_data_type_[i]),
phi::make_ddim(output_shape));
std::shared_ptr<phi::Allocation> output_alloc(
new phi::Allocation(reinterpret_cast<void*>(outputs[i]),
output_numel * sizeof(float),
place));
phi::DenseTensor output_densetonsor(output_alloc, output_meta);
(*dense_tensor_outputs_)[i] =
std::move(phi::DenseTensor(output_alloc, output_meta));
phi_kernel_context_->EmplaceBackOutput(&((*dense_tensor_outputs_)[i]));
}
(*phi_kernel_)(phi_kernel_context_);
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <NvInfer.h>
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h"
#include "paddle/fluid/memory/allocation/cuda_allocator.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_context.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc,
phi::KernelContext* kernel_context,
const phi::KernelSignature& signature,
const phi::Kernel& phi_kernel);
class GenericPlugin : public DynamicPluginTensorRT {
public:
struct InputOutPutVarInfo {
std::vector<int> inputs_data_type;
std::vector<int> outputs_data_type;
};
public:
GenericPlugin() {}
GenericPlugin(const paddle::framework::proto::OpDesc& proto_op_desc,
const InputOutPutVarInfo& in_out_info);
GenericPlugin(const paddle::framework::proto::OpDesc& proto_op_desc,
const std::vector<int>& inputs_data_type,
const std::vector<int>& outputs_data_type);
// It was used for tensorrt deserialization.
// It should not be called by users.
GenericPlugin(void const* serialData, size_t serialLength);
// IPluginV2 method
const char* getPluginType() const TRT_NOEXCEPT override {
return "generic_plugin";
}
int getNbOutputs() const TRT_NOEXCEPT override;
int getNbInputs() const TRT_NOEXCEPT;
// Initialize the layer for execution.
int initialize() TRT_NOEXCEPT override;
// Shutdown the layer. This is called when the engine is destroyed
void terminate() TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT{};
size_t getSerializationSize() const TRT_NOEXCEPT {
return op_meta_data_.size() + SerializedSize(inputs_data_type_) +
SerializedSize(outputs_data_type_);
}
void serialize(void* buffer) const TRT_NOEXCEPT;
// The Func in IPluginV2
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT;
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nb_outputs) TRT_NOEXCEPT;
int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT;
private:
std::string op_meta_data_;
framework::proto::OpDesc proto_op_desc_;
framework::OpDesc op_desc_;
private:
phi::KernelContext* phi_kernel_context_;
const phi::Kernel* phi_kernel_;
std::vector<phi::DenseTensor>* dense_tensor_inputs_;
std::vector<phi::DenseTensor>* dense_tensor_outputs_;
private:
InputOutPutVarInfo in_out_info_;
std::vector<int> inputs_data_type_;
std::vector<int> outputs_data_type_;
};
class GenericPluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override {
return "generic_plugin";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
nvinfer1::IPluginV2DynamicExt* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
return new GenericPlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(GenericPluginCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -125,10 +125,11 @@ class MishPluginDynamic : public DynamicPluginTensorRT {
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(int output_index,
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder)
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h"
namespace paddle {
namespace inference {
namespace tensorrt {
bool PluginArgumentMappingContext::HasInput(const std::string& name) const {
auto inputs = op_desc_ptr_->Inputs();
for (auto& i : inputs) {
if (i.first == name && !i.second.empty()) return true;
}
return false;
}
bool PluginArgumentMappingContext::HasOutput(const std::string& name) const {
auto outputs = op_desc_ptr_->Outputs();
for (auto& i : outputs) {
if (i.first == name && !i.second.empty()) return true;
}
return false;
}
bool PluginArgumentMappingContext::HasAttr(const std::string& name) const {
return op_desc_ptr_->HasAttr(name);
}
paddle::any PluginArgumentMappingContext::Attr(
const std::string& attr_name) const {
auto attr_type = op_desc_ptr_->GetAttrType(attr_name);
switch (attr_type) {
case framework::proto::AttrType::INT: {
return PADDLE_GET_CONST(int, op_desc_ptr_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::FLOAT: {
return PADDLE_GET_CONST(float, op_desc_ptr_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::STRING: {
return PADDLE_GET_CONST(std::string, op_desc_ptr_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::INTS: {
return PADDLE_GET_CONST(std::vector<int>,
op_desc_ptr_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::FLOATS: {
return PADDLE_GET_CONST(std::vector<float>,
op_desc_ptr_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::STRINGS: {
return PADDLE_GET_CONST(std::vector<std::string>,
op_desc_ptr_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::BOOLEAN: {
return PADDLE_GET_CONST(bool, op_desc_ptr_->GetAttr(attr_name));
break;
};
case framework::proto::AttrType::BOOLEANS: {
return PADDLE_GET_CONST(std::vector<bool>,
op_desc_ptr_->GetAttr(attr_name));
break;
};
default: {
LOG(ERROR) << "Can't conver op's attribute [" << attr_name
<< "] to paddle any.";
}
}
return paddle::any();
}
size_t PluginArgumentMappingContext::InputSize(const std::string& name) const {
return op_desc_ptr_->Inputs().at(name).size();
}
size_t PluginArgumentMappingContext::OutputSize(const std::string& name) const {
return op_desc_ptr_->Outputs().at(name).size();
}
bool PluginArgumentMappingContext::IsDenseTensorInput(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsDenseTensorInputs(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsSelectedRowsInput(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsDenseTensorVectorInput(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsDenseTensorOutput(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsSelectedRowsOutput(
const std::string& name) const {
return false;
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/phi/core/compat/arg_map_context.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext {
public:
explicit PluginArgumentMappingContext(framework::OpDesc* op_desc_ptr)
: op_desc_ptr_(op_desc_ptr) {}
bool HasInput(const std::string& name) const override;
bool HasOutput(const std::string& name) const override;
bool HasAttr(const std::string& name) const override;
paddle::any Attr(const std::string& attr_name) const override;
size_t InputSize(const std::string& name) const override;
size_t OutputSize(const std::string& name) const override;
bool IsDenseTensorInput(const std::string& name) const override;
bool IsDenseTensorInputs(const std::string& name) const override;
bool IsSelectedRowsInput(const std::string& name) const override;
bool IsDenseTensorVectorInput(const std::string& name) const override;
bool IsDenseTensorOutput(const std::string& name) const override;
bool IsSelectedRowsOutput(const std::string& name) const override;
bool IsForInferShape() const override { return false; }
private:
framework::OpDesc* op_desc_ptr_;
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(ArgMappingContexTest, BasicFunction) {
paddle::framework::proto::OpDesc op;
op.set_type("imaged_op");
auto *input_var = op.add_inputs();
input_var->set_parameter("X");
*input_var->add_arguments() = "input";
auto *output_var = op.add_outputs();
output_var->set_parameter("Out");
*output_var->add_arguments() = "output";
auto *attr = op.add_attrs();
attr->set_name("int_attr");
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(1);
attr = op.add_attrs();
attr->set_name("float_attr");
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(1.0);
attr = op.add_attrs();
attr->set_name("string_attr");
attr->set_type(paddle::framework::proto::AttrType::STRING);
attr->set_s("1");
attr = op.add_attrs();
attr->set_name("bool_attr");
attr->set_type(paddle::framework::proto::AttrType::BOOLEAN);
attr->set_b(true);
attr = op.add_attrs();
attr->set_name("ints_attr");
attr->set_type(paddle::framework::proto::AttrType::INTS);
attr->add_ints(1);
attr->add_ints(2);
attr = op.add_attrs();
attr->set_name("floats_attr");
attr->set_type(paddle::framework::proto::AttrType::FLOATS);
attr->add_floats(1.0);
attr->add_floats(2.0);
attr = op.add_attrs();
attr->set_name("strings_attr");
attr->set_type(paddle::framework::proto::AttrType::STRINGS);
attr->add_strings("1");
attr->add_strings("2");
attr = op.add_attrs();
attr->set_name("bools_attr");
attr->set_type(paddle::framework::proto::AttrType::BOOLEANS);
attr->add_bools(true);
attr->add_bools(true);
framework::OpDesc op_desc(op, nullptr);
PluginArgumentMappingContext context(&op_desc);
EXPECT_EQ(context.HasInput("X"), true);
EXPECT_EQ(context.HasOutput("Out"), true);
EXPECT_EQ(context.HasAttr("int_attr"), true);
int int_attr = any_cast<int>(context.Attr("int_attr"));
EXPECT_EQ(int_attr, 1);
float flaot_attr = any_cast<float>(context.Attr("float_attr"));
EXPECT_EQ(flaot_attr, 1);
std::string string_attr = any_cast<std::string>(context.Attr("string_attr"));
EXPECT_EQ(string_attr, "1");
bool bool_attr = any_cast<bool>(context.Attr("bool_attr"));
EXPECT_EQ(bool_attr, true);
std::vector<int> ints_attr =
any_cast<std::vector<int>>(context.Attr("ints_attr"));
EXPECT_EQ(ints_attr[0], 1);
EXPECT_EQ(ints_attr[1], 2);
std::vector<float> floats_attr =
any_cast<std::vector<float>>(context.Attr("floats_attr"));
EXPECT_EQ(floats_attr[0], 1.0);
EXPECT_EQ(floats_attr[1], 2.0);
std::vector<std::string> strings_attr =
any_cast<std::vector<std::string>>(context.Attr("strings_attr"));
EXPECT_EQ(strings_attr[0], "1");
EXPECT_EQ(strings_attr[1], "2");
std::vector<bool> bools_attr =
any_cast<std::vector<bool>>(context.Attr("bools_attr"));
EXPECT_EQ(bools_attr[0], true);
EXPECT_EQ(bools_attr[1], true);
EXPECT_EQ(context.InputSize("X"), true);
EXPECT_EQ(context.OutputSize("Out"), true);
EXPECT_EQ(context.IsDenseTensorInput("X"), false);
EXPECT_EQ(context.IsDenseTensorInputs("X"), false);
EXPECT_EQ(context.IsSelectedRowsInput("X"), false);
EXPECT_EQ(context.IsDenseTensorVectorInput("X"), false);
EXPECT_EQ(context.IsDenseTensorOutput("Out"), false);
EXPECT_EQ(context.IsSelectedRowsOutput("Out"), false);
EXPECT_EQ(context.IsForInferShape(), false);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -159,6 +159,8 @@ void DynamicShapeTest(bool allow_build_at_runtime) {
// Execute them.
LOG(INFO) << "engine_op run";
inference::tensorrt::OpTeller::Global().SetOpConverterType(
"fc", inference::tensorrt::OpConverterType::Default);
engine_op->Run(scope, place);
}
......
......@@ -19,11 +19,15 @@ import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import os
class TrtConvertGatherNdTest_dim_4_1(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
# The output has diff between gpu and trt in CI windows
# if ( and self.trt_param.precision == paddle_infer.PrecisionType.Half):
# return False
return True
def sample_program_configs(self):
......@@ -46,13 +50,15 @@ class TrtConvertGatherNdTest_dim_4_1(TrtLayerAutoScanTest):
"op_attrs": {}
}]
ops = self.generate_op_config(ops_config)
for i in range(10):
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(data_gen=partial(generate_input1)),
"index_data": TensorConfig(data_gen=partial(generate_input2)),
"input_data":
TensorConfig(data_gen=partial(generate_input1)),
"index_data":
TensorConfig(data_gen=partial(generate_input2)),
},
outputs=["output_data"])
......@@ -71,7 +77,7 @@ class TrtConvertGatherNdTest_dim_4_1(TrtLayerAutoScanTest):
"index_data": [1]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 4, 64, 64],
"input_data": [2, 32, 64, 64],
"index_data": [1]
}
......@@ -94,11 +100,23 @@ class TrtConvertGatherNdTest_dim_4_1(TrtLayerAutoScanTest):
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (0, 4), 1e-5
yield self.create_inference_config(), (1, 3), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (0, 4), 1e-5
yield self.create_inference_config(), (1, 3), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt':
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_SUPPORT,
"Under Windows Ci, this case will sporadically fail.")
def test(self):
self.add_skip_trt_case()
self.run_test()
......@@ -145,14 +163,14 @@ class TrtConvertGatherNdTest_dim_4_1_2(TrtLayerAutoScanTest):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data": [1, 8, 8, 8],
"index_data": [1]
"index_data": [2]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 64, 64],
"index_data": [4]
"index_data": [2]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 4, 64, 64],
"input_data": [2, 32, 64, 64],
"index_data": [2]
}
......@@ -175,11 +193,23 @@ class TrtConvertGatherNdTest_dim_4_1_2(TrtLayerAutoScanTest):
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (0, 4), 1e-5
yield self.create_inference_config(), (1, 3), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (0, 4), 1e-5
yield self.create_inference_config(), (1, 3), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt':
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_SUPPORT,
"Under Windows Ci, this case will sporadically fail.")
def test(self):
self.add_skip_trt_case()
self.run_test()
......@@ -226,14 +256,14 @@ class TrtConvertGatherNdTest_dim_4_2(TrtLayerAutoScanTest):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data": [1, 8, 8, 8],
"index_data": [1, 2]
"index_data": [2, 2]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 64, 64],
"index_data": [4, 4]
"index_data": [2, 2]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 4, 64, 64],
"input_data": [2, 32, 64, 64],
"index_data": [2, 2]
}
......@@ -256,11 +286,23 @@ class TrtConvertGatherNdTest_dim_4_2(TrtLayerAutoScanTest):
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (0, 4), 1e-5
yield self.create_inference_config(), (1, 3), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (0, 4), 1e-5
yield self.create_inference_config(), (1, 3), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt':
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_SUPPORT,
"Under Windows Ci, this case will sporadically fail.")
def test(self):
self.add_skip_trt_case()
self.run_test()
......@@ -307,15 +349,15 @@ class TrtConvertGatherNdTest_dim_4_3(TrtLayerAutoScanTest):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data": [1, 8, 8, 8],
"index_data": [1, 2, 2]
"index_data": [2, 2, 4]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 64, 64],
"index_data": [4, 4, 4]
"index_data": [2, 2, 4]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 4, 64, 64],
"index_data": [2, 2, 2]
"input_data": [2, 32, 64, 64],
"index_data": [2, 2, 4]
}
def clear_dynamic_shape():
......@@ -337,11 +379,23 @@ class TrtConvertGatherNdTest_dim_4_3(TrtLayerAutoScanTest):
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (0, 4), 1e-5
yield self.create_inference_config(), (1, 3), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (0, 4), 1e-5
yield self.create_inference_config(), (1, 3), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt':
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_SUPPORT,
"Under Windows Ci, this case will sporadically fail.")
def test(self):
self.add_skip_trt_case()
self.run_test()
......@@ -388,11 +442,11 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data": [1, 4],
"index_data": [1, 1]
"index_data": [2, 2]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 64],
"index_data": [4, 2]
"index_data": [2, 2]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 8],
......@@ -418,11 +472,23 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest):
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (0, 4), 1e-5
yield self.create_inference_config(), (1, 3), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (0, 4), 1e-5
yield self.create_inference_config(), (1, 3), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt':
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_SUPPORT,
"Under Windows Ci, this case will sporadically fail.")
def test(self):
self.add_skip_trt_case()
self.run_test()
......
......@@ -107,24 +107,30 @@ class TrtConvertYoloBoxTest(TrtLayerAutoScanTest):
if attrs[0]['iou_aware'] == True:
channel = 3 * (attrs[0]['class_num'] + 6)
self.dynamic_shape.min_input_shape = {
"scale_input": [1, channel, 12, 12]
"yolo_box_input": [1, channel, 12, 12],
"imgsize": [1, 2]
}
self.dynamic_shape.max_input_shape = {
"scale_input": [4, channel, 24, 24]
"yolo_box_input": [4, channel, 24, 24],
"imgsize": [4, 2]
}
self.dynamic_shape.opt_input_shape = {
"scale_input": [1, channel, 24, 24]
"yolo_box_input": [1, channel, 24, 24],
"imgsize": [1, 2]
}
else:
channel = 3 * (attrs[0]['class_num'] + 5)
self.dynamic_shape.min_input_shape = {
"scale_input": [1, channel, 12, 12]
"yolo_box_input": [1, channel, 12, 12],
"imgsize": [1, 2]
}
self.dynamic_shape.max_input_shape = {
"scale_input": [4, channel, 24, 24]
"yolo_box_input": [4, channel, 24, 24],
"imgsize": [4, 2]
}
self.dynamic_shape.opt_input_shape = {
"scale_input": [1, channel, 24, 24]
"yolo_box_input": [1, channel, 24, 24],
"imgsize": [1, 2]
}
def clear_dynamic_shape():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册