Unverified commit 394f92aa, authored by zlsh80826 and committed by GitHub

[Paddle-TRT] IPluginExt -> IPluginV2 (#33680)

* add trt LT version helper

* upgrade PluginTensorRT to IPluginV2Ext

* trt plugin factory is not usable in IPluginV2

* upgrade add plugin api to use IPluginV2

* remove IPlugin register and adapt getSerializeSize(), serialize()

* adapt IPluginV2Layer

* downgrade to IPluginV2

* implement elementwise clone

* add gelu plugin creator and fix gelu serialization bug

* add swish plugin creator and fix swish serialization bug

* format

* fix typo

* add elementwise plugin creator and fix serialization

* add base creator class

* add gelu plugin creator

* add hard swish creator and fix serialization

* add instance norm creator and fix serialization

* add layer norm creator and fix serialization

* add pool creator and fix serialization

* add prelu creator and fix serialization

* add slice creator and fix serialization

* add swish creator and fix serialization

* add instance norm op unittest

* remove redundant api

* fix wrong graph size to enable trt

* instance norm function move to cc

* add trt elementwise ut to trigger coverage

* remove opt cache to hit serialization coverage

* remove opt cache to hit serialization coverage

* remove unused code

* remove unused inputs_

* add dbg info

* remove dbg info

* add instance norm serialization

* roll back

* remove comment code

* remove trt plugin registry

* fix prelu dynamic serialization

* add prelu ut and reduce the input size to reduce memory usage

* fix pool dynamic plugin serialization and add ut

* refine pool ut with subtest

* add env for avoiding oom

* reduce test input size & increase pool op ut to 45s

* add the contributor

* remove copyright (will add in contributor)

* remove copyright (will add in contributor)
Parent 0b20b76e
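In short, this PR moves every static-shape Paddle-TRT plugin from the deprecated nvinfer1::IPluginExt interface to nvinfer1::IPluginV2, deletes the old PluginFactoryTensorRT / REGISTER_TRT_PLUGIN deserialization path (trt_plugin_factory.{h,cc}), and gives each plugin a creator class derived from the new TensorRTPluginCreator base, registered through REGISTER_TRT_PLUGIN_V2. The per-plugin hunks below all follow the same pattern; a condensed sketch of it, using a hypothetical FooPlugin rather than any one plugin from the diff:

// Condensed sketch of the new pattern; FooPlugin is a placeholder for the
// concrete plugins touched below (elementwise, gelu, hard_swish, instance
// norm, layer norm, pool, prelu, slice, swish).
class FooPluginCreator : public TensorRTPluginCreator {
 public:
  // Must match FooPlugin::getPluginType()/getPluginVersion() so the TensorRT
  // runtime can locate the creator while deserializing an engine.
  const char* getPluginName() const override { return "foo_plugin"; }
  const char* getPluginVersion() const override { return "1"; }

  // Replaces the old free function previously registered via
  // REGISTER_TRT_PLUGIN("foo_plugin", CreateFooPluginDeserialize).
  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serial_data,
                                         size_t serial_length) override {
    return new FooPlugin(serial_data, serial_length);
  }
};
REGISTER_TRT_PLUGIN_V2(FooPluginCreator);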
......@@ -251,10 +251,10 @@ class ElementwiseTensorOpConverter : public OpConverter {
} else {
plugin::ElementWisePlugin* plugin =
new plugin::ElementWisePlugin(op_type_, dims_x, dims_y, axis);
plugin->AddInput(X);
plugin->AddInput(Y);
nvinfer1::IPluginLayer* plugin_layer = engine_->AddPlugin(
plugin->GetInputs().data(), 2,
std::vector<nvinfer1::ITensor*> inputs{X, Y};
auto* plugin_layer = engine_->AddPlugin(
inputs.data(), inputs.size(),
reinterpret_cast<plugin::PluginTensorRT*>(plugin));
layer = plugin_layer;
......
......@@ -74,7 +74,7 @@ class InstanceNormOpConverter : public OpConverter {
plugin::InstanceNormPlugin* plugin =
new plugin::InstanceNormPlugin(eps, scale_v, bias_v);
plugin->getPluginType();
nvinfer1::IPluginLayer* layer = engine_->AddPlugin(&input, 1, plugin);
auto* layer = engine_->AddPlugin(&input, 1, plugin);
auto output_name = op_desc.Output("Y")[0];
RreplenishLayerAndOutput(layer, "instance_norm", {output_name}, test_mode);
......
......@@ -61,7 +61,8 @@ class ShuffleChannelOpConverter : public OpConverter {
reshape_layer->setReshapeDimensions(reshape_dim2);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(reshape_layer, "concat", {output_name}, test_mode);
RreplenishLayerAndOutput(reshape_layer, "shuffle_channel", {output_name},
test_mode);
}
};
......
......@@ -330,11 +330,11 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }
nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
nvinfer1::IPluginV2Layer *TensorRTEngine::AddPlugin(
nvinfer1::ITensor *const *inputs, int num_inputs,
plugin::PluginTensorRT *plugin) {
owned_plugin_.emplace_back(plugin);
return network()->addPluginExt(inputs, num_inputs, *plugin);
return network()->addPluginV2(inputs, num_inputs, *plugin);
}
nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2Ext(
......
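Taken together with the elementwise converter hunk above, the static-shape path after this change looks roughly as follows; this is a condensed re-assembly of the two hunks for orientation, not a verbatim excerpt:

// In the op converter: inputs are now passed as a plain vector instead of
// being stashed on the plugin via the removed AddInput()/GetInputs().
std::vector<nvinfer1::ITensor*> inputs{X, Y};
auto* plugin_layer = engine_->AddPlugin(
    inputs.data(), inputs.size(),
    reinterpret_cast<plugin::PluginTensorRT*>(plugin));

// In TensorRTEngine::AddPlugin: the engine keeps ownership of the raw plugin
// pointer and adds it through the IPluginV2 network API rather than the
// removed addPluginExt.
owned_plugin_.emplace_back(plugin);
return network()->addPluginV2(inputs, num_inputs, *plugin);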
......@@ -30,7 +30,6 @@ limitations under the License. */
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "paddle/fluid/inference/utils/singleton.h"
......@@ -276,19 +275,8 @@ class TensorRTEngine {
}
}
if (with_dynamic_shape_) {
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size()));
} else {
#if IS_TRT_VERSION_LT(8000)
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size(),
&inference::Singleton<plugin::PluginFactoryTensorRT>::Global()));
#else
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size()));
#endif
}
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size()));
PADDLE_ENFORCE_NOT_NULL(
infer_engine_,
......@@ -311,8 +299,8 @@ class TensorRTEngine {
int GetDeviceId() { return device_id_; }
nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
int num_inputs, plugin::PluginTensorRT*);
nvinfer1::IPluginV2Layer* AddPlugin(nvinfer1::ITensor* const* inputs,
int num_inputs, plugin::PluginTensorRT*);
nvinfer1::IPluginV2Layer* AddPluginV2Ext(nvinfer1::ITensor* const* inputs,
int num_inputs,
......
nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
prelu_op_plugin.cu gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
......
......@@ -18,8 +18,6 @@
#include <cassert>
#include "paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/detection/anchor_generator_op.h"
namespace paddle {
......
......@@ -14,19 +14,12 @@ limitations under the License. */
#include <glog/logging.h>
#include "paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
ElementWisePlugin *CreateElementWisePluginDeserialize(const void *buffer,
size_t length) {
return new ElementWisePlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("elementwise_plugin", CreateElementWisePluginDeserialize);
namespace details {
template <typename T>
struct Add {
......
......@@ -40,14 +40,16 @@ class ElementWisePlugin : public PluginTensorRT {
const char* elementwise_type;
DeserializeValue(&serial_data, &serial_length, &elementwise_type);
type_ = std::string(elementwise_type);
DeserializeValue(&serial_data, &serial_length, &axis_);
DeserializeValue(&serial_data, &serial_length, &dims_x_);
DeserializeValue(&serial_data, &serial_length, &dims_y_);
DeserializeValue(&serial_data, &serial_length, &axis_);
DeserializeValue(&serial_data, &serial_length, &prev_size_);
DeserializeValue(&serial_data, &serial_length, &midd_size_);
DeserializeValue(&serial_data, &serial_length, &post_size_);
}
ElementWisePlugin* clone() const override {
// return new ElementWisePlugin(dims_x_, dims_y_, axis_);
return nullptr;
return new ElementWisePlugin(type_, dims_x_, dims_y_, axis_);
}
const char* getPluginType() const override { return "elementwise_plugin"; }
......@@ -65,22 +67,25 @@ class ElementWisePlugin : public PluginTensorRT {
#endif
void* workspace, cudaStream_t stream);
protected:
size_t getSerializationSize() override {
return SerializedSize(getPluginType()) + SerializedSize(axis_) +
size_t getSerializationSize() const override {
return getBaseSerializationSize() + SerializedSize(type_.c_str()) +
SerializedSize(dims_x_) + SerializedSize(dims_y_) +
getBaseSerializationSize();
SerializedSize(axis_) + SerializedSize(prev_size_) +
SerializedSize(midd_size_) + SerializedSize(post_size_);
}
void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
void serialize(void* buffer) const override {
serializeBase(buffer);
SerializeValue(&buffer, type_.c_str());
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, dims_x_);
SerializeValue(&buffer, dims_y_);
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, prev_size_);
SerializeValue(&buffer, midd_size_);
SerializeValue(&buffer, post_size_);
}
protected:
std::string type_;
nvinfer1::Dims dims_x_;
nvinfer1::Dims dims_y_;
......@@ -90,6 +95,20 @@ class ElementWisePlugin : public PluginTensorRT {
int post_size_;
};
class ElementWisePluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const override { return "elementwise_plugin"; }
const char* getPluginVersion() const override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new ElementWisePlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(ElementWisePluginCreator);
#if IS_TRT_VERSION_GE(6000)
class ElementwisePluginDynamic : public DynamicPluginTensorRT {
public:
......@@ -105,7 +124,9 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT {
return new ElementwisePluginDynamic(type_, axis_);
}
const char* getPluginType() const override { return "elementwise_plugin"; }
const char* getPluginType() const override {
return "elementwise_plugin_dynamic";
}
int getNbOutputs() const override { return 1; }
int initialize() override;
......@@ -150,7 +171,9 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT {
class ElementwisePluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
ElementwisePluginDynamicCreator() {}
const char* getPluginName() const override { return "elementwise_plugin"; }
const char* getPluginName() const override {
return "elementwise_plugin_dynamic";
}
const char* getPluginVersion() const override { return "1"; }
......
......@@ -20,7 +20,6 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
namespace paddle {
......
......@@ -22,7 +22,6 @@
#include "NvInferRuntimeCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/gather_nd_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
......
......@@ -16,7 +16,6 @@
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......@@ -31,12 +30,6 @@ static const float kAT = 0.5;
static const float kBT = 0.7978845608028654; // sqrt(2.0/M_PI)
static const float kCT = 0.035677408136300125; // 0.044715 * sqrt(2.0/M_PI)
GeluPlugin* CreateGeluPluginDeserialize(const void* buffer, size_t length) {
return new GeluPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("gelu_plugin", CreateGeluPluginDeserialize);
bool GeluPlugin::supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const {
if (with_fp16_) {
......
......@@ -51,18 +51,28 @@ class GeluPlugin : public PluginTensorRT {
#endif
void* workspace, cudaStream_t stream) override;
protected:
size_t getSerializationSize() override {
return getBaseSerializationSize() + SerializedSize(getPluginType());
size_t getSerializationSize() const override {
return getBaseSerializationSize();
}
// TRT will call this func to serialize the configuration of TRT
// It should not be called by users.
void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
void serialize(void* buffer) const override { serializeBase(buffer); }
};
class GeluPluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const override { return "gelu_plugin"; }
const char* getPluginVersion() const override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new GeluPlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(GeluPluginCreator);
#if IS_TRT_VERSION_GE(6000)
class GeluPluginDynamic : public DynamicPluginTensorRT {
......@@ -77,7 +87,7 @@ class GeluPluginDynamic : public DynamicPluginTensorRT {
return new GeluPluginDynamic(with_fp16_);
}
const char* getPluginType() const override { return "gelu_plugin"; }
const char* getPluginType() const override { return "gelu_plugin_dynamic"; }
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }
......@@ -119,44 +129,19 @@ class GeluPluginDynamic : public DynamicPluginTensorRT {
void destroy() override { delete this; }
};
class GeluPluginDynamicCreator : public nvinfer1::IPluginCreator {
class GeluPluginDynamicCreator : public TensorRTPluginCreator {
public:
GeluPluginDynamicCreator() {}
const char* getPluginName() const override { return "gelu_plugin"; }
const char* getPluginName() const override { return "gelu_plugin_dynamic"; }
const char* getPluginVersion() const override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
auto plugin = new GeluPluginDynamic(serial_data, serial_length);
return plugin;
}
void setPluginNamespace(const char* lib_namespace) override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(GeluPluginDynamicCreator);
#endif
......
......@@ -15,20 +15,12 @@
#include <cassert>
#include <cstring>
#include "paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
HardSwishPlugin* CreateHardSwishPluginDeserialize(const void* buffer,
size_t length) {
return new HardSwishPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("hard_swish_plugin", CreateHardSwishPluginDeserialize);
nvinfer1::Dims HardSwishPlugin::getOutputDimensions(
int index, const nvinfer1::Dims* in_dims, int nb_inputs) {
assert(nb_inputs == 1);
......
......@@ -56,27 +56,39 @@ class HardSwishPlugin : public PluginTensorRT {
#endif
void* workspace, cudaStream_t stream) override;
protected:
float threshold_;
float scale_;
float offset_;
size_t getSerializationSize() override {
size_t getSerializationSize() const override {
return getBaseSerializationSize() + SerializedSize(threshold_) +
SerializedSize(scale_) + SerializedSize(offset_) +
SerializedSize(getPluginType());
SerializedSize(scale_) + SerializedSize(offset_);
}
// TRT will call this func to serialize the configuration of TRT
// It should not be called by users.
void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
void serialize(void* buffer) const override {
serializeBase(buffer);
SerializeValue(&buffer, threshold_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, offset_);
}
protected:
float threshold_;
float scale_;
float offset_;
};
class HardSwishPluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const override { return "hard_swish_plugin"; }
const char* getPluginVersion() const override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new HardSwishPlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(HardSwishPluginCreator);
} // namespace plugin
} // namespace tensorrt
......
......@@ -17,7 +17,6 @@
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
......@@ -40,13 +39,6 @@ cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype,
return CUDNN_STATUS_SUCCESS;
}
InstanceNormPlugin *CreateInstanceNormPluginDeserialize(const void *buffer,
size_t length) {
return new InstanceNormPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("instance_norm_plugin",
CreateInstanceNormPluginDeserialize);
int InstanceNormPlugin::initialize() { return 0; }
nvinfer1::Dims InstanceNormPlugin::getOutputDimensions(
......@@ -58,6 +50,13 @@ nvinfer1::Dims InstanceNormPlugin::getOutputDimensions(
return output_dims;
}
bool InstanceNormPlugin::supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const {
return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kLINEAR));
}
int InstanceNormPlugin::enqueue(int batch_size, const void *const *inputs,
#if IS_TRT_VERSION_LT(8000)
void **outputs, void *workspace,
......
......@@ -38,25 +38,22 @@ class InstanceNormPlugin : public PluginTensorRT {
cudnnHandle_t handle_;
cudnnTensorDescriptor_t x_desc_, y_desc_, b_desc_;
protected:
size_t getSerializationSize() override {
public:
size_t getSerializationSize() const override {
return getBaseSerializationSize() + SerializedSize(eps_) +
SerializedSize(scale_) + SerializedSize(bias_) +
SerializedSize(getPluginType());
SerializedSize(scale_) + SerializedSize(bias_);
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void *buffer) override {
SerializeValue(&buffer, getPluginType());
void serialize(void *buffer) const override {
serializeBase(buffer);
SerializeValue(&buffer, eps_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, bias_);
}
public:
explicit InstanceNormPlugin(const float eps, const std::vector<float> scale,
const std::vector<float> bias)
: eps_(eps), scale_(scale), bias_(bias) {
......@@ -91,6 +88,7 @@ class InstanceNormPlugin : public PluginTensorRT {
platform::dynload::cudnnDestroyTensorDescriptor(y_desc_);
platform::dynload::cudnnDestroyTensorDescriptor(b_desc_);
}
int initialize() override;
InstanceNormPlugin *clone() const override {
......@@ -101,6 +99,7 @@ class InstanceNormPlugin : public PluginTensorRT {
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override;
#if IS_TRT_VERSION_LT(8000)
int enqueue(int batchSize, const void *const *inputs, void **outputs,
#else
......@@ -109,12 +108,22 @@ class InstanceNormPlugin : public PluginTensorRT {
void *workspace, cudaStream_t stream) override;
bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const override {
return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kLINEAR));
nvinfer1::PluginFormat format) const override;
};
class InstanceNormPluginCreator : public TensorRTPluginCreator {
public:
const char *getPluginName() const override { return "instance_norm_plugin"; }
const char *getPluginVersion() const override { return "1"; }
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
const void *serial_data,
size_t serial_length) override {
return new InstanceNormPlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(InstanceNormPluginCreator);
} // namespace plugin
} // namespace tensorrt
......
......@@ -17,7 +17,6 @@
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/layer_norm_op.h"
namespace paddle {
......@@ -25,12 +24,6 @@ namespace inference {
namespace tensorrt {
namespace plugin {
LayerNormPlugin *CreateLayerNormPluginDeserialize(const void *buffer,
size_t length) {
return new LayerNormPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("layer_norm_plugin", CreateLayerNormPluginDeserialize);
int LayerNormPlugin::initialize() { return 0; }
nvinfer1::Dims LayerNormPlugin::getOutputDimensions(
......
......@@ -39,19 +39,18 @@ class LayerNormPlugin : public PluginTensorRT {
std::vector<int64_t> mean_shape_;
std::vector<int64_t> variance_shape_;
protected:
size_t getSerializationSize() override {
public:
size_t getSerializationSize() const override {
return getBaseSerializationSize() + SerializedSize(bias_) +
SerializedSize(scale_) + SerializedSize(begin_norm_axis_) +
SerializedSize(eps_) + SerializedSize(mean_shape_) +
SerializedSize(variance_shape_) + SerializedSize(getPluginType());
SerializedSize(variance_shape_);
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
void serialize(void* buffer) const override {
serializeBase(buffer);
SerializeValue(&buffer, bias_);
SerializeValue(&buffer, scale_);
......@@ -61,7 +60,6 @@ class LayerNormPlugin : public PluginTensorRT {
SerializeValue(&buffer, variance_shape_);
}
public:
LayerNormPlugin(const float* bias, const int bias_num, const float* scale,
const int scale_num, int begin_norm_axis, float eps,
std::vector<int64_t> mean_shape,
......@@ -96,7 +94,7 @@ class LayerNormPlugin : public PluginTensorRT {
mean_shape_, variance_shape_);
}
const char* getPluginType() const override { return "layer_norm_plugin"; }
const char* getPluginType() const override { return "layernorm_plugin"; }
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
int nbInputDims) override;
......@@ -108,6 +106,20 @@ class LayerNormPlugin : public PluginTensorRT {
void* workspace, cudaStream_t stream) override;
};
class LayerNormPluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const override { return "layernorm_plugin"; }
const char* getPluginVersion() const override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new LayerNormPlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(LayerNormPluginCreator);
class LayerNormPluginDynamic : public DynamicPluginTensorRT {
public:
LayerNormPluginDynamic(const float* bias, const int bias_num,
......@@ -139,7 +151,9 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT {
mean_shape_, variance_shape_);
}
const char* getPluginType() const override { return "layernorm_plugin"; }
const char* getPluginType() const override {
return "layernorm_plugin_dynamic";
}
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }
......@@ -201,42 +215,19 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT {
std::vector<int64_t> variance_shape_;
};
class LayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
class LayerNormPluginDynamicCreator : public TensorRTPluginCreator {
public:
LayerNormPluginDynamicCreator() {}
const char* getPluginName() const override { return "layernorm_plugin"; }
const char* getPluginVersion() const override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() override {
return &field_collection_;
const char* getPluginName() const override {
return "layernorm_plugin_dynamic";
}
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
return nullptr;
}
const char* getPluginVersion() const override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
auto plugin = new LayerNormPluginDynamic(serial_data, serial_length);
return plugin;
}
void setPluginNamespace(const char* lib_namespace) override {
plugin_namespace_ = lib_namespace;
return new LayerNormPluginDynamic(serial_data, serial_length);
}
const char* getPluginNamespace() const override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(LayerNormPluginDynamicCreator);
......
......@@ -13,7 +13,6 @@
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace paddle {
......@@ -21,11 +20,6 @@ namespace inference {
namespace tensorrt {
namespace plugin {
PoolPlugin *CreatePoolPluginDeserialize(const void *buffer, size_t length) {
return new PoolPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("pool_plugin", CreatePoolPluginDeserialize);
nvinfer1::Dims PoolPlugin::getOutputDimensions(int index,
const nvinfer1::Dims *inputDims,
int nbInputs) {
......@@ -80,9 +74,35 @@ int PoolPlugin::enqueue(int batchSize, const void *const *inputs,
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
size_t PoolPluginDynamic::getSerializationSize() const { return 0; }
PoolPluginDynamic::PoolPluginDynamic(void const *serialData,
size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &ceil_mode_);
const char *pool_type;
DeserializeValue(&serialData, &serialLength, &pool_type);
pool_type_ = std::string(pool_type);
DeserializeValue(&serialData, &serialLength, &adaptive_);
DeserializeValue(&serialData, &serialLength, &ksize_);
DeserializeValue(&serialData, &serialLength, &strides_);
DeserializeValue(&serialData, &serialLength, &paddings_);
DeserializeValue(&serialData, &serialLength, &is_global_);
}
size_t PoolPluginDynamic::getSerializationSize() const {
return SerializedSize(ceil_mode_) + SerializedSize(pool_type_.c_str()) +
SerializedSize(adaptive_) + SerializedSize(ksize_) +
SerializedSize(strides_) + SerializedSize(paddings_) +
SerializedSize(is_global_);
}
void PoolPluginDynamic::serialize(void *buffer) const {}
void PoolPluginDynamic::serialize(void *buffer) const {
SerializeValue(&buffer, ceil_mode_);
SerializeValue(&buffer, pool_type_.c_str());
SerializeValue(&buffer, adaptive_);
SerializeValue(&buffer, ksize_);
SerializeValue(&buffer, strides_);
SerializeValue(&buffer, paddings_);
SerializeValue(&buffer, is_global_);
}
nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
......
......@@ -56,19 +56,18 @@ static std::vector<int> CalcOutputSize(const std::vector<int>& input_shape,
}
class PoolPlugin : public PluginTensorRT {
protected:
size_t getSerializationSize() override {
return SerializedSize(getPluginType()) + SerializedSize(ceil_mode_) +
public:
size_t getSerializationSize() const override {
return getBaseSerializationSize() + SerializedSize(ceil_mode_) +
SerializedSize(pool_type_) + SerializedSize(adaptive_) +
SerializedSize(ksize_) + SerializedSize(strides_) +
SerializedSize(paddings_) + SerializedSize(input_shape_) +
SerializedSize(output_shape_) + getBaseSerializationSize();
SerializedSize(output_shape_);
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
void serialize(void* buffer) const override {
serializeBase(buffer);
SerializeValue(&buffer, ceil_mode_);
SerializeValue(&buffer, pool_type_);
......@@ -80,7 +79,6 @@ class PoolPlugin : public PluginTensorRT {
SerializeValue(&buffer, output_shape_);
}
public:
enum class PoolType {
max = 0,
avg,
......@@ -146,6 +144,20 @@ class PoolPlugin : public PluginTensorRT {
std::vector<int> output_shape_;
};
class PoolPluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const override { return "pool_plugin"; }
const char* getPluginVersion() const override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new PoolPlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(PoolPluginCreator);
#if IS_TRT_VERSION_GE(6000)
class PoolPluginDynamic : public DynamicPluginTensorRT {
public:
......@@ -162,25 +174,14 @@ class PoolPluginDynamic : public DynamicPluginTensorRT {
paddings_(paddings),
is_global_(is_global) {}
PoolPluginDynamic(void const* serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &ceil_mode_);
const char* pool_type;
DeserializeValue(&serialData, &serialLength, &pool_type);
pool_type_ = std::string(pool_type);
DeserializeValue(&serialData, &serialLength, &adaptive_);
DeserializeValue(&serialData, &serialLength, &ksize_);
DeserializeValue(&serialData, &serialLength, &strides_);
DeserializeValue(&serialData, &serialLength, &paddings_);
DeserializeValue(&serialData, &serialLength, &is_global_);
}
PoolPluginDynamic(void const* serialData, size_t serialLength);
~PoolPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* clone() const override {
return new PoolPluginDynamic(ceil_mode_, pool_type_, adaptive_, ksize_,
strides_, paddings_, is_global_);
}
const char* getPluginType() const override { return "pool_plugin"; }
const char* getPluginType() const override { return "pool_plugin_dynamic"; }
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }
......@@ -226,6 +227,20 @@ class PoolPluginDynamic : public DynamicPluginTensorRT {
std::vector<int> paddings_;
bool is_global_;
};
class PoolPluginDynamicCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const override { return "pool_plugin_dynamic"; }
const char* getPluginVersion() const override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new PoolPluginDynamic(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(PoolPluginDynamicCreator);
#endif
} // namespace plugin
......
......@@ -19,7 +19,6 @@
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/prelu.h"
namespace paddle {
......@@ -27,11 +26,6 @@ namespace inference {
namespace tensorrt {
namespace plugin {
PReluPlugin *CreatePreluPluginDeserialize(const void *buffer, size_t length) {
return new PReluPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("prelu_plugin", CreatePreluPluginDeserialize);
int PReluPlugin::initialize() {
cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size());
cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float),
......@@ -104,9 +98,23 @@ int PReluPluginDynamic::initialize() {
cudaMemcpyHostToDevice);
return 0;
}
size_t PReluPluginDynamic::getSerializationSize() const { return 0; }
void PReluPluginDynamic::serialize(void *buffer) const {}
PReluPluginDynamic::PReluPluginDynamic(void const *serialData,
size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &weight_);
const char *prelu_mode;
DeserializeValue(&serialData, &serialLength, &prelu_mode);
mode_ = std::string(prelu_mode);
}
size_t PReluPluginDynamic::getSerializationSize() const {
return SerializedSize(mode_.c_str()) + SerializedSize(weight_);
}
void PReluPluginDynamic::serialize(void *buffer) const {
SerializeValue(&buffer, weight_);
SerializeValue(&buffer, mode_.c_str());
}
nvinfer1::DimsExprs PReluPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
......
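Before this change, PReluPluginDynamic::getSerializationSize() returned 0 and serialize() was a no-op, so a serialized engine could not restore the plugin's weights; the hunk above writes weight_ and mode_ and the deserializing constructor reads them back in the same order. A round-trip sketch of what that symmetry buys (assumptions: the dynamic plugin keeps the same (weight, weight_num, mode) constructor shown for PReluPlugin below, and "all" is just an example mode):

// Round-trip sketch, not part of the PR.
float weight[1] = {0.25f};
PReluPluginDynamic plugin(weight, 1, "all");

std::vector<char> buffer(plugin.getSerializationSize());
plugin.serialize(buffer.data());  // writes weight_, then mode_.c_str()

// During engine deserialization the creator rebuilds the plugin from the same
// bytes; the constructor consumes weight_ and mode_ in the same order.
PReluPluginDynamicCreator creator;
nvinfer1::IPluginV2* restored = creator.deserializePlugin(
    "prelu_plugin_dynamic", buffer.data(), buffer.size());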
......@@ -33,23 +33,21 @@ class PReluPlugin : public PluginTensorRT {
float* p_gpu_weight_;
std::string mode_;
protected:
size_t getSerializationSize() override {
public:
size_t getSerializationSize() const override {
return getBaseSerializationSize() + SerializedSize(mode_.c_str()) +
SerializedSize(weight_) + SerializedSize(getPluginType());
SerializedSize(weight_);
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
void serialize(void* buffer) const override {
serializeBase(buffer);
SerializeValue(&buffer, weight_);
SerializeValue(&buffer, mode_.c_str());
}
public:
PReluPlugin(const float* weight, const int weight_num,
std::string const& mode)
: mode_(mode) {
......@@ -88,6 +86,20 @@ class PReluPlugin : public PluginTensorRT {
void* workspace, cudaStream_t stream) override;
};
class PReluPluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const override { return "prelu_plugin"; }
const char* getPluginVersion() const override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new PReluPlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(PReluPluginCreator);
#if IS_TRT_VERSION_GE(6000)
class PReluPluginDynamic : public DynamicPluginTensorRT {
public:
......@@ -98,15 +110,7 @@ class PReluPluginDynamic : public DynamicPluginTensorRT {
std::copy(weight, weight + weight_num, weight_.data());
}
// It was used for tensorrt deserialization.
// It should not be called by users.
PReluPluginDynamic(void const* serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &weight_);
const char* prelu_mode;
DeserializeValue(&serialData, &serialLength, &prelu_mode);
mode_ = std::string(prelu_mode);
}
PReluPluginDynamic(void const* serialData, size_t serialLength);
~PReluPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* clone() const override {
auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_);
......@@ -114,7 +118,7 @@ class PReluPluginDynamic : public DynamicPluginTensorRT {
return ptr;
}
const char* getPluginType() const override { return "prelu_plugin"; }
const char* getPluginType() const override { return "prelu_plugin_dynamic"; }
int getNbOutputs() const override { return 1; }
int initialize() override;
void terminate() override;
......@@ -159,6 +163,20 @@ class PReluPluginDynamic : public DynamicPluginTensorRT {
};
#endif
class PReluPluginDynamicCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const override { return "prelu_plugin_dynamic"; }
const char* getPluginVersion() const override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new PReluPluginDynamic(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(PReluPluginDynamicCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
......@@ -20,7 +20,6 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/blas.h"
......
......@@ -17,7 +17,6 @@
#include <algorithm>
#include "paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
......
......@@ -19,7 +19,6 @@
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
namespace paddle {
......
......@@ -19,18 +19,12 @@
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
SlicePlugin *CreateSlicePluginDeserialize(const void *buffer, size_t length) {
return new SlicePlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("slice_plugin", CreateSlicePluginDeserialize);
template <typename T>
__global__ void SliceKernel(int num, int dims, const T *input,
const int *offsets_info, T *output) {
......@@ -193,13 +187,13 @@ int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
return cudaGetLastError() != cudaSuccess;
}
size_t SlicePlugin::getSerializationSize() {
size_t SlicePlugin::getSerializationSize() const {
return getBaseSerializationSize() + SerializedSize(getPluginType()) +
SerializedSize(starts_) + SerializedSize(ends_) +
SerializedSize(axes_);
}
void SlicePlugin::serialize(void *buffer) {
void SlicePlugin::serialize(void *buffer) const {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
SerializeValue(&buffer, starts_);
......
......@@ -51,12 +51,11 @@ class SlicePlugin : public PluginTensorRT {
#endif
void* workspace, cudaStream_t stream) override;
protected:
size_t getSerializationSize() override;
size_t getSerializationSize() const override;
// TRT will call this func to serialize the configuration of TRT
// It should not be called by users.
void serialize(void* buffer) override;
void serialize(void* buffer) const override;
private:
std::vector<int> starts_;
......@@ -67,6 +66,20 @@ class SlicePlugin : public PluginTensorRT {
cudaStream_t copy_stream_;
};
class SlicePluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const override { return "slice_plugin"; }
const char* getPluginVersion() const override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new SlicePlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(SlicePluginCreator);
#if IS_TRT_VERSION_GE(6000)
class SlicePluginDynamic : public DynamicPluginTensorRT {
public:
......@@ -79,7 +92,7 @@ class SlicePluginDynamic : public DynamicPluginTensorRT {
SlicePluginDynamic(void const* serialData, size_t serialLength);
const char* getPluginType() const override { return "slice_plugin"; }
const char* getPluginType() const override { return "slice_plugin_dynamic"; }
int getNbOutputs() const override { return 1; }
int initialize() override;
......@@ -125,40 +138,18 @@ class SlicePluginDynamic : public DynamicPluginTensorRT {
cudaStream_t copy_stream_;
};
class SlicePluginDynamicCreator : public nvinfer1::IPluginCreator {
class SlicePluginDynamicCreator : public TensorRTPluginCreator {
public:
SlicePluginDynamicCreator() {}
const char* getPluginName() const override { return "slice_plugin"; }
const char* getPluginName() const override { return "slice_plugin_dynamic"; }
const char* getPluginVersion() const override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serialData,
size_t serialLength) override {
auto plugin = new SlicePluginDynamic(serialData, serialLength);
return plugin;
return new SlicePluginDynamic(serialData, serialLength);
}
void setPluginNamespace(const char* libNamespace) override {
namespace_ = libNamespace;
}
const char* getPluginNamespace() const override { return namespace_.c_str(); }
private:
std::string namespace_;
nvinfer1::PluginFieldCollection field_collection_;
};
REGISTER_TRT_PLUGIN_V2(SlicePluginDynamicCreator);
#endif
......
......@@ -16,7 +16,6 @@
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
......
......@@ -15,7 +15,6 @@
#include <cuda_fp16.h>
#include <algorithm>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
......
......@@ -16,7 +16,6 @@
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
......
......@@ -17,18 +17,12 @@
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
SwishPlugin *CreateSwishPluginDeserialize(const void *buffer, size_t length) {
return new SwishPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("swish_plugin", CreateSwishPluginDeserialize);
int SwishPlugin::initialize() { return 0; }
nvinfer1::Dims SwishPlugin::getOutputDimensions(int index,
......
......@@ -30,22 +30,16 @@ class SwishPlugin : public PluginTensorRT {
private:
float beta_;
protected:
size_t getSerializationSize() override {
return SerializedSize(getPluginType()) + getBaseSerializationSize() +
SerializedSize(beta_);
public:
size_t getSerializationSize() const override {
return getBaseSerializationSize() + SerializedSize(beta_);
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
void serialize(void* buffer) const override {
serializeBase(buffer);
SerializeValue(&buffer, beta_);
}
public:
explicit SwishPlugin(const float beta, const bool with_fp16) : beta_(beta) {
with_fp16_ = with_fp16;
}
......@@ -56,7 +50,9 @@ class SwishPlugin : public PluginTensorRT {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &beta_);
}
~SwishPlugin() {}
int initialize() override;
SwishPlugin* clone() const override {
......@@ -75,6 +71,20 @@ class SwishPlugin : public PluginTensorRT {
void* workspace, cudaStream_t stream) override;
};
class SwishPluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const override { return "swish_plugin"; }
const char* getPluginVersion() const override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new SwishPlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(SwishPluginCreator);
#if IS_TRT_VERSION_GE(6000)
class SwishPluginDynamic : public DynamicPluginTensorRT {
public:
......@@ -90,7 +100,7 @@ class SwishPluginDynamic : public DynamicPluginTensorRT {
return new SwishPluginDynamic(beta_, with_fp16_);
}
const char* getPluginType() const override { return "swish_plugin"; }
const char* getPluginType() const override { return "swish_plugin_dynamic"; }
int getNbOutputs() const override { return 1; }
int initialize() override;
......@@ -131,44 +141,18 @@ class SwishPluginDynamic : public DynamicPluginTensorRT {
float beta_;
};
class SwishPluginDynamicCreator : public nvinfer1::IPluginCreator {
class SwishPluginDynamicCreator : public TensorRTPluginCreator {
public:
SwishPluginDynamicCreator() {}
const char* getPluginName() const override { return "swish_plugin"; }
const char* getPluginName() const override { return "swish_plugin_dynamic"; }
const char* getPluginVersion() const override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
auto plugin = new SwishPluginDynamic(serial_data, serial_length);
return plugin;
return new SwishPluginDynamic(serial_data, serial_length);
}
void setPluginNamespace(const char* lib_namespace) override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(SwishPluginDynamicCreator);
#endif
......
......@@ -21,10 +21,9 @@ namespace plugin {
inline void Seria(void*& buffer, // NOLINT
const std::vector<nvinfer1::Dims>& input_dims,
size_t max_batch_size, nvinfer1::DataType data_type,
nvinfer1::DataType data_type,
nvinfer1::PluginFormat data_format, bool with_fp16) {
SerializeValue(&buffer, input_dims);
SerializeValue(&buffer, max_batch_size);
SerializeValue(&buffer, data_type);
SerializeValue(&buffer, data_format);
SerializeValue(&buffer, with_fp16);
......@@ -32,37 +31,33 @@ inline void Seria(void*& buffer, // NOLINT
inline void Deseria(void const*& serial_data, size_t& serial_length, // NOLINT
std::vector<nvinfer1::Dims>* input_dims,
size_t* max_batch_size, nvinfer1::DataType* data_type,
nvinfer1::DataType* data_type,
nvinfer1::PluginFormat* data_format, bool* with_fp16) {
DeserializeValue(&serial_data, &serial_length, input_dims);
DeserializeValue(&serial_data, &serial_length, max_batch_size);
DeserializeValue(&serial_data, &serial_length, data_type);
DeserializeValue(&serial_data, &serial_length, data_format);
DeserializeValue(&serial_data, &serial_length, with_fp16);
}
inline size_t SeriaSize(const std::vector<nvinfer1::Dims>& input_dims,
size_t max_batch_size, nvinfer1::DataType data_type,
nvinfer1::DataType data_type,
nvinfer1::PluginFormat data_format, bool with_fp16) {
return (SerializedSize(input_dims) + SerializedSize(max_batch_size) +
SerializedSize(data_type) + SerializedSize(data_format) +
SerializedSize(with_fp16));
return (SerializedSize(input_dims) + SerializedSize(data_type) +
SerializedSize(data_format) + SerializedSize(with_fp16));
}
void PluginTensorRT::serializeBase(void*& buffer) {
Seria(buffer, input_dims_, max_batch_size_, data_type_, data_format_,
with_fp16_);
void PluginTensorRT::serializeBase(void*& buffer) const {
Seria(buffer, input_dims_, data_type_, data_format_, with_fp16_);
}
void PluginTensorRT::deserializeBase(void const*& serial_data,
size_t& serial_length) {
Deseria(serial_data, serial_length, &input_dims_, &max_batch_size_,
&data_type_, &data_format_, &with_fp16_);
Deseria(serial_data, serial_length, &input_dims_, &data_type_, &data_format_,
&with_fp16_);
}
size_t PluginTensorRT::getBaseSerializationSize() {
return SeriaSize(input_dims_, max_batch_size_, data_type_, data_format_,
with_fp16_);
size_t PluginTensorRT::getBaseSerializationSize() const {
return SeriaSize(input_dims_, data_type_, data_format_, with_fp16_);
}
bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
......@@ -78,23 +73,20 @@ void PluginTensorRT::configureWithFormat(
data_type_ = type;
data_format_ = format;
input_dims_.assign(input_dims, input_dims + num_inputs);
max_batch_size_ = max_batch_size;
}
void PluginTensorRTV2Ext::serializeBase(void*& buffer) const {
Seria(buffer, input_dims_, max_batch_size_, data_type_, data_format_,
with_fp16_);
Seria(buffer, input_dims_, data_type_, data_format_, with_fp16_);
}
void PluginTensorRTV2Ext::deserializeBase(void const*& serial_data,
size_t& serial_length) {
Deseria(serial_data, serial_length, &input_dims_, &max_batch_size_,
&data_type_, &data_format_, &with_fp16_);
Deseria(serial_data, serial_length, &input_dims_, &data_type_, &data_format_,
&with_fp16_);
}
size_t PluginTensorRTV2Ext::getBaseSerializationSize() const {
return SeriaSize(input_dims_, max_batch_size_, data_type_, data_format_,
with_fp16_);
return SeriaSize(input_dims_, data_type_, data_format_, with_fp16_);
}
void PluginTensorRTV2Ext::configurePlugin(
......@@ -105,11 +97,27 @@ void PluginTensorRTV2Ext::configurePlugin(
const bool* output_is_broadcast, nvinfer1::PluginFormat float_format,
int32_t max_batch_size) {
input_dims_.assign(input_dims, input_dims + nb_inputs);
max_batch_size_ = max_batch_size;
data_format_ = float_format;
data_type_ = input_types[0];
}
const nvinfer1::PluginFieldCollection* TensorRTPluginCreator::getFieldNames() {
return &field_collection_;
}
nvinfer1::IPluginV2* TensorRTPluginCreator::createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) {
return nullptr;
}
void TensorRTPluginCreator::setPluginNamespace(const char* lib_namespace) {
plugin_namespace_ = lib_namespace;
}
const char* TensorRTPluginCreator::getPluginNamespace() const {
return plugin_namespace_.c_str();
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
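With max_batch_size_ dropped from the base payload, the layout every static-shape plugin shares is input_dims_, data_type_, data_format_, with_fp16_, followed by the plugin's own fields. A sketch for a hypothetical plugin with one extra field, following the same pattern the Swish/Gelu/Pool hunks in this diff use (FooPlugin and alpha_ are placeholders):

// Sketch only; the helpers are the ones defined in the trt_plugin.cc hunk above.
size_t FooPlugin::getSerializationSize() const {
  return getBaseSerializationSize() + SerializedSize(alpha_);
}
void FooPlugin::serialize(void* buffer) const {
  serializeBase(buffer);            // input_dims_, data_type_, data_format_, with_fp16_
  SerializeValue(&buffer, alpha_);  // plugin-specific fields come last
}
FooPlugin::FooPlugin(void const* serial_data, size_t serial_length) {
  deserializeBase(serial_data, serial_length);  // must mirror serialize() order
  DeserializeValue(&serial_data, &serial_length, &alpha_);
}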
......@@ -45,43 +45,55 @@ typedef std::function<PluginTensorRT*(const void*, size_t)>
typedef std::function<PluginTensorRT*(void)> PluginConstructFunc;
// Deprecated. Do not inherit this class, please refer to PluginTensorRTV2Ext
class PluginTensorRT : public nvinfer1::IPluginExt {
class PluginTensorRT : public nvinfer1::IPluginV2 {
public:
PluginTensorRT() : with_fp16_(false) {}
// It was used for TensorRT deserialization.
// It should not be called by users.
PluginTensorRT(const void* serialized_data, size_t length) {}
virtual ~PluginTensorRT() {}
nvinfer1::Dims const& getInputDims(int index) const {
return input_dims_.at(index);
}
size_t getMaxBatchSize() const { return max_batch_size_; }
nvinfer1::DataType getDataType() const { return data_type_; }
nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
virtual const char* getPluginVersion() const { return "1"; }
void AddInput(nvinfer1::ITensor* input) { inputs_.push_back(input); }
std::vector<nvinfer1::ITensor*>& GetInputs() { return inputs_; }
nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
virtual nvinfer1::IPluginExt* clone() const = 0;
// IPluginV2
virtual const char* getPluginType() const = 0;
// Following functions are inherit from nvinfer1::IPluginExt
// Get the number of outputs from the layer
virtual const char* getPluginVersion() const { return "1"; }
int getNbOutputs() const { return 1; }
// Get the dimension of an output tensor
virtual nvinfer1::Dims getOutputDimensions(int index,
const nvinfer1::Dims* input_dims,
int num_inputs) = 0;
// Find the workspace size required by the layer
size_t getWorkspaceSize(int) const override { return 0; }
// Check format support. The default is FLOAT32 and kLINEAR.
bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const override;
// Configure the layer
void configureWithFormat(const nvinfer1::Dims* input_dims, int num_inputs,
const nvinfer1::Dims* output_dims, int num_outputs,
nvinfer1::DataType type,
nvinfer1::PluginFormat format,
int max_batch_size) override;
// Initialize the layer for execution.
// This is called when the engine is created.
int initialize() override { return 0; }
// Shutdown the layer. This is called when the engine is destroyed
void terminate() override {}
// Find the workspace size required by the layer
size_t getWorkspaceSize(int) const override { return 0; }
// Execute the layer
#if IS_TRT_VERSION_LT(8000)
virtual int enqueue(int batch_size, const void* const* inputs, void** outputs,
......@@ -92,37 +104,39 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
void* workspace, cudaStream_t stream) = 0;
// Find the size of the serialization buffer required
virtual size_t getSerializationSize() = 0;
virtual size_t getSerializationSize() const = 0;
// Serialize the layer config to buffer.
// TensorRT will call this func to serialize the configuration of TensorRT
// engine. It should not be called by users.
virtual void serialize(void* buffer) = 0;
virtual void serialize(void* buffer) const = 0;
// Check format support. The default is FLOAT32 and NCHW.
bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const override;
// Configure the layer
void configureWithFormat(const nvinfer1::Dims* input_dims, int num_inputs,
const nvinfer1::Dims* output_dims, int num_outputs,
nvinfer1::DataType type,
nvinfer1::PluginFormat format,
int max_batch_size) override;
void destroy() override { delete this; }
virtual nvinfer1::IPluginV2* clone() const = 0;
void setPluginNamespace(const char* plugin_namespace) override {
namespace_ = plugin_namespace;
}
const char* getPluginNamespace() const override { return namespace_.c_str(); }
protected:
// Deserialize input_dims, max_batch_size, data_type, data_format
void deserializeBase(void const*& serial_data, // NOLINT
size_t& serial_length); // NOLINT
size_t getBaseSerializationSize();
size_t getBaseSerializationSize() const;
// Serialize input_dims, max_batch_size, data_type, data_format
void serializeBase(void*& buffer); // NOLINT
void serializeBase(void*& buffer) const; // NOLINT
std::vector<nvinfer1::Dims> input_dims_;
size_t max_batch_size_;
nvinfer1::DataType data_type_;
nvinfer1::PluginFormat data_format_;
std::vector<nvinfer1::ITensor*> inputs_;
bool with_fp16_;
private:
std::string namespace_;
};
// TensorRT introduced IPluginV2Ext after 5.1, Paddle no longer supports
......@@ -135,7 +149,6 @@ class PluginTensorRTV2Ext : public nvinfer1::IPluginV2Ext {
nvinfer1::Dims const& getInputDims(int index) const {
return input_dims_.at(index);
}
size_t getMaxBatchSize() const { return max_batch_size_; }
nvinfer1::DataType getDataType() const { return data_type_; }
nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
......@@ -228,10 +241,8 @@ class PluginTensorRTV2Ext : public nvinfer1::IPluginV2Ext {
protected:
std::vector<nvinfer1::Dims> input_dims_;
size_t max_batch_size_;
nvinfer1::DataType data_type_;
nvinfer1::PluginFormat data_format_;
std::vector<nvinfer1::ITensor*> inputs_;
bool with_fp16_;
private:
......@@ -305,6 +316,34 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
};
#endif
class TensorRTPluginCreator : public nvinfer1::IPluginCreator {
public:
TensorRTPluginCreator() = default;
virtual const char* getPluginName() const = 0;
virtual const char* getPluginVersion() const = 0;
const nvinfer1::PluginFieldCollection* getFieldNames() override;
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override;
virtual nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) = 0;
void setPluginNamespace(const char* lib_namespace) override;
const char* getPluginNamespace() const override;
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
template <typename T>
class TrtPluginRegistrarV2 {
public:
......
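The body of TrtPluginRegistrarV2 is truncated in this view, so the exact expansion of REGISTER_TRT_PLUGIN_V2 is not shown; what the rest of the diff relies on is only the end state, namely that each creator becomes discoverable through TensorRT's global plugin registry. A hedged sketch of that end state using the standard nvinfer1 registry API (this illustrates the observable effect, not the Paddle implementation; serial_data/serial_length stand for a plugin's serialized payload):

// Sketch of the effect of REGISTER_TRT_PLUGIN_V2(GeluPluginCreator): once the
// creator is registered, the runtime can look it up by the name/version it
// reports and rebuild the plugin from serialized bytes during engine
// deserialization.
auto* creator = getPluginRegistry()->getPluginCreator("gelu_plugin", "1", "");
if (creator != nullptr) {
  nvinfer1::IPluginV2* plugin =
      creator->deserializePlugin("gelu_plugin", serial_data, serial_length);
}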
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name,
const void* serial_data,
size_t serial_length) {
const char* plugin_type;
DeserializeValue(&serial_data, &serial_length, &plugin_type);
PADDLE_ENFORCE_EQ(
Has(plugin_type), true,
platform::errors::NotFound("TensorRT plugin type `%s` does not exists.",
plugin_type));
auto plugin = plugin_registry_[plugin_type](serial_data, serial_length);
owned_plugins_.emplace_back(plugin);
return plugin;
}
bool PluginFactoryTensorRT::RegisterPlugin(
const std::string& op_name, PluginDeserializeFunc deserialize_func) {
if (Has(op_name)) return false;
auto ret = plugin_registry_.emplace(op_name, deserialize_func);
return ret.second;
}
void PluginFactoryTensorRT::DestroyPlugins() { owned_plugins_.clear(); }
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <NvInfer.h>
#include <cstring>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class PluginFactoryTensorRT : public nvinfer1::IPluginFactory,
public DeleteHelper {
public:
// Deserialization method
PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data,
size_t serial_length) override;
bool RegisterPlugin(const std::string& op_name,
PluginDeserializeFunc deserialize_func);
bool Has(const std::string& op_name) {
return plugin_registry_.find(op_name) != plugin_registry_.end();
}
void DestroyPlugins();
protected:
std::unordered_map<std::string, PluginDeserializeFunc> plugin_registry_;
std::list<std::unique_ptr<PluginTensorRT>> owned_plugins_;
};
class TrtPluginRegistrar {
public:
TrtPluginRegistrar(const std::string& name,
PluginDeserializeFunc deserialize_func) {
inference::Singleton<PluginFactoryTensorRT>::Global().RegisterPlugin(
name, deserialize_func);
}
};
#define REGISTER_TRT_PLUGIN(name, deserialize_func) \
REGISTER_TRT_PLUGIN_UNIQ(__COUNTER__, name, deserialize_func)
#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func) \
static paddle::inference::tensorrt::plugin::TrtPluginRegistrar \
trt_plugin_registrar##ctr UNUSED = \
paddle::inference::tensorrt::plugin::TrtPluginRegistrar( \
name, deserialize_func)
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
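For contrast, the legacy path being removed worked through this factory: a free function matching PluginDeserializeFunc was registered per plugin type, and createPlugin() dispatched on the type string read from the head of the serialized buffer. A minimal sketch of that old registration, with the plugin name and constructor assumed for illustration:
// Legacy-path sketch only; names are assumed for illustration.
PluginTensorRT* CreateGeluPluginDeserialize(const void* serial_data,
                                            size_t serial_length) {
  return new GeluPlugin(serial_data, serial_length);
}
// Matches the registry lookup in PluginFactoryTensorRT::createPlugin above.
REGISTER_TRT_PLUGIN("gelu_plugin", CreateGeluPluginDeserialize);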
......@@ -17,7 +17,6 @@
#include <algorithm>
#include <cassert>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h"
#include "paddle/fluid/operators/detection/yolo_box_op.h"
......
......@@ -35,4 +35,5 @@ set_tests_properties(test_trt_activation_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_trt_conv_pass PROPERTIES TIMEOUT 120)
#set_tests_properties(test_trt_multiclass_nms_op PROPERTIES TIMEOUT 200)
set_tests_properties(test_trt_dynamic_shape PROPERTIES TIMEOUT 120)
set_tests_properties(test_trt_pool_op PROPERTIES ENVIRONMENT FLAGS_fraction_of_gpu_memory_to_use=0.1 TIMEOUT 45)
endif()
......@@ -33,11 +33,11 @@ class TensorRTSubgraphPassActivationTest(InferencePassTest):
self.setUpTensorRTParam()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 6, 64, 64], dtype="float32")
name="data", shape=[-1, 6, 32, 32], dtype="float32")
act_out = self.append_act(data)
out = fluid.layers.batch_norm(act_out, is_test=True)
self.feeds = {
"data": np.random.random([1, 6, 64, 64]).astype("float32"),
"data": np.random.random([1, 6, 32, 32]).astype("float32"),
}
self.fetch_list = [out]
......@@ -154,6 +154,71 @@ class TensorRTSubgraphPassPreluElementTest(TensorRTSubgraphPassActivationTest):
return fluid.layers.prelu(x, mode='element')
class TensorRTSubgraphPassPreluDynamicTest(TensorRTSubgraphPassActivationTest):
def setUpTensorRTParam(self):
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassActivationTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.dynamic_shape_params = TensorRTSubgraphPassActivationTest.DynamicShapeParam(
{
'data': [1, 6, 8, 8]
}, {'data': [1, 6, 512, 512]}, {'data': [1, 6, 256, 256]}, False)
def append_act(self, x):
return fluid.layers.prelu(x, mode='all')
class TensorRTSubgraphPassPreluFp16Test(TensorRTSubgraphPassActivationTest):
def setUpTensorRTParam(self):
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassActivationTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Half, False, False)
def append_act(self, x):
return fluid.layers.prelu(x, mode='all')
class TensorRTSubgraphPassPreluFp16SerializeTest(
TensorRTSubgraphPassActivationTest):
def setUpTensorRTParam(self):
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassActivationTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Half, True, False)
def append_act(self, x):
return fluid.layers.prelu(x, mode='all')
class TensorRTSubgraphPassPreluFp16DynamicTest(
TensorRTSubgraphPassActivationTest):
def setUpTensorRTParam(self):
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassActivationTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Half, False, False)
self.dynamic_shape_params = TensorRTSubgraphPassActivationTest.DynamicShapeParam(
{
'data': [1, 6, 8, 8]
}, {'data': [1, 6, 512, 512]}, {'data': [1, 6, 256, 256]}, False)
def append_act(self, x):
return fluid.layers.prelu(x, mode='all')
class TensorRTSubgraphPassPreluFp16DynamicSerializeTest(
TensorRTSubgraphPassActivationTest):
def setUpTensorRTParam(self):
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassActivationTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Half, True, False)
self.dynamic_shape_params = TensorRTSubgraphPassActivationTest.DynamicShapeParam(
{
'data': [1, 6, 8, 8]
}, {'data': [1, 6, 512, 512]}, {'data': [1, 6, 256, 256]}, False)
def append_act(self, x):
return fluid.layers.prelu(x, mode='all')
class TensorRTSubgraphPassGeluTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.gelu(x)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import shutil
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig
class TensorRTSubgraphPassElementwiseBroadcastTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data1 = fluid.data(
name="data1", shape=[-1, 3, 64, 64], dtype="float32")
data2 = fluid.data(
name="data2", shape=[-1, 3, 64, 1], dtype="float32")
eltwise_out = self.append_eltwise(data1, data2)
out = fluid.layers.batch_norm(eltwise_out, is_test=True)
self.feeds = {
"data1": np.random.random([1, 3, 64, 64]).astype("float32"),
"data2": np.random.random([1, 3, 64, 1]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassElementwiseBroadcastTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, True, False)
self.fetch_list = [out]
def append_eltwise(self, data1, data2):
return fluid.layers.elementwise_add(x=data1, y=data2, axis=0)
def test_check_output(self):
if os.path.exists(self.path + "_opt_cache"):
shutil.rmtree(self.path + "_opt_cache")
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import shutil
import unittest
import itertools
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig
class TRTInstanceNormTest(InferencePassTest):
def setUp(self):
self.bs = 4
self.channel = 4
self.height = 8
self.width = 8
self.precision = AnalysisConfig.Precision.Float32
self.serialize = False
self.enable_trt = True
def build(self):
self.trt_parameters = InferencePassTest.TensorRTParam(
1 << 30, self.bs, 2, self.precision, self.serialize, False)
with fluid.program_guard(self.main_program, self.startup_program):
shape = [-1, self.channel, self.height, self.width]
data = fluid.data(name='in', shape=shape, dtype='float32')
instance_norm_out = fluid.layers.instance_norm(data)
out = fluid.layers.batch_norm(instance_norm_out, is_test=True)
shape[0] = self.bs
self.feeds = {'in': np.random.random(shape).astype('float32'), }
self.fetch_list = [out]
def check_output(self, remove_cache=False):
if remove_cache and os.path.exists(self.path + "_opt_cache"):
shutil.rmtree(self.path + "_opt_cache")
if core.is_compiled_with_cuda():
use_gpu = True
atol = 1e-5
if self.trt_parameters.precision == AnalysisConfig.Precision.Half:
atol = 2e-2
self.check_output_with_option(use_gpu, atol, flatten=True)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
def run_test(self, remove_cache=False):
self.build()
self.check_output(remove_cache)
def run_all_tests(self):
precision_opt = [
AnalysisConfig.Precision.Float32, AnalysisConfig.Precision.Half
]
serialize_opt = [False, True]
for precision, serialize in itertools.product(precision_opt,
serialize_opt):
self.precision = precision
self.serialize = serialize
self.run_test()
def test_base(self):
self.run_test()
def test_fp16(self):
self.precision = AnalysisConfig.Precision.Half
self.run_test()
def test_serialize(self):
self.serialize = True
self.run_test(remove_cache=True)
def test_all(self):
self.run_all_tests()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import unittest
import itertools
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig
class TensorRTPoolTest(InferencePassTest):
def setUp(self):
self.bs = 1
self.channel = 3
self.height = 8
self.width = 8
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = False
self.ceil_mode = False
self.exclusive = False
self.enable_trt = True
self.serialize = False
self.precision = AnalysisConfig.Precision.Float32
self.feeds = {
'data':
np.random.random([self.bs, self.channel, self.height,
self.width]).astype('float32'),
}
def set_extra_config(self):
pass
def build_network(self):
self.set_extra_config()
self.trt_parameters = TensorRTPoolTest.TensorRTParam(
1 << 30, self.bs, 0, self.precision, self.serialize, False)
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name='data',
shape=[-1, self.channel, self.height, self.width],
dtype='float32')
pool_out = fluid.layers.pool2d(
input=data,
pool_size=self.pool_size,
pool_type=self.pool_type,
pool_stride=self.pool_stride,
pool_padding=self.pool_padding,
global_pooling=self.global_pooling,
ceil_mode=self.ceil_mode,
exclusive=self.exclusive)
out = fluid.layers.batch_norm(pool_out, is_test=True)
self.fetch_list = [out]
def check_output(self):
if os.path.exists(self.path + "_opt_cache"):
shutil.rmtree(self.path + "_opt_cache")
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
def run_test(self):
self.build_network()
self.check_output()
def test(self):
precision_options = [
AnalysisConfig.Precision.Float32, AnalysisConfig.Precision.Half
]
serialize_options = [False, True]
dynamic_shape_profile = InferencePassTest.DynamicShapeParam(
{
'data':
[self.bs, self.channel, self.height // 2, self.width // 2]
}, {'data': [self.bs, self.channel, self.height, self.width]},
{'data': [self.bs, self.channel, self.height, self.width]}, False)
dynamic_shape_options = [None, dynamic_shape_profile]
for precision, serialize, dynamic_shape in itertools.product(
precision_options, serialize_options, dynamic_shape_options):
            is_dynamic = dynamic_shape is not None
with self.subTest('Precision: {}, Serialize: {}, Dynamic: {}'.
format(precision, serialize, is_dynamic)):
self.precision = precision
self.serialize = serialize
                self.dynamic_shape_params = dynamic_shape
self.run_test()
class TensorRTAvgPoolTest(TensorRTPoolTest):
def set_extra_config(self):
self.pool_size = 2
self.pool_type = 'avg'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = False
self.ceil_mode = False
self.exclusive = False
class TensorRTGlobalPoolTest(TensorRTPoolTest):
def set_extra_config(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = True
self.ceil_mode = False
self.exclusive = False
class TensorRTCeilPoolTest(TensorRTPoolTest):
def set_extra_config(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = False
self.ceil_mode = True
self.exclusive = False
class TensorRTExclusivePoolTest(TensorRTPoolTest):
def set_extra_config(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = False
self.ceil_mode = False
self.exclusive = True
class TensorRTSamePaddingPoolTest(TensorRTPoolTest):
def set_extra_config(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 'SAME'
self.global_pooling = False
self.ceil_mode = False
self.exclusive = False
class TensorRTValidPaddingPoolTest(TensorRTPoolTest):
def set_extra_config(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 'VALID'
self.global_pooling = False
self.ceil_mode = False
self.exclusive = False
if __name__ == "__main__":
unittest.main()
......@@ -47,113 +47,6 @@ class TensorRTSubgraphPassFcTest(InferencePassTest):
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassPoolTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 6, 64, 64], dtype="float32")
pool_out = fluid.layers.pool2d(
input=data,
pool_size=self.pool_size,
pool_type=self.pool_type,
pool_stride=self.pool_stride,
pool_padding=self.pool_padding,
global_pooling=self.global_pooling,
ceil_mode=self.ceil_mode,
exclusive=self.exclusive)
out = fluid.layers.batch_norm(pool_out, is_test=True)
self.feeds = {
"data": np.random.random([1, 6, 64, 64]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassPoolTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def set_params(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = False
self.ceil_mode = False
self.exclusive = False
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassAvgPoolTest(TensorRTSubgraphPassPoolTest):
def set_params(self):
self.pool_size = 2
self.pool_type = 'avg'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = False
self.ceil_mode = False
self.exclusive = False
class TensorRTSubgraphPassGlobalPoolTest(TensorRTSubgraphPassPoolTest):
def set_params(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = True
self.ceil_mode = False
self.exclusive = False
class TensorRTSubgraphPassCeilPoolTest(TensorRTSubgraphPassPoolTest):
def set_params(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = False
self.ceil_mode = True
self.exclusive = False
class TensorRTSubgraphPassExclusivePoolTest(TensorRTSubgraphPassPoolTest):
def set_params(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = False
self.ceil_mode = False
self.exclusive = True
class TensorRTSubgraphPassSamePaddingPoolTest(InferencePassTest):
def set_params(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 'SAME'
self.global_pooling = False
self.ceil_mode = False
self.exclusive = False
class TensorRTSubgraphPassValidPaddingPoolTest(InferencePassTest):
def set_params(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 'VALID'
self.global_pooling = False
self.ceil_mode = False
self.exclusive = False
class TensorRTSubgraphPassConcatTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
......