未验证 提交 c59c8e4f 编写于 作者: 提交者: GitHub

[inference]add hard_swish dynamic plugin (#35214)

上级 d43f797a
......@@ -63,11 +63,23 @@ class HardSwishOpConverter : public OpConverter {
engine_, ElementWise, *input, *(hsig_layer->getOutput(0)),
nvinfer1::ElementWiseOperation::kPROD);
layer = eltwise_layer;
} else {
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
plugin::HardSwishPluginDynamic* plugin =
new plugin::HardSwishPluginDynamic(threshold, scale, offset);
layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
} else {
plugin::HardSwishPlugin* plugin =
new plugin::HardSwishPlugin(threshold, scale, offset);
layer = engine_->AddPlugin(&input, input_num, plugin);
}
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "hard_swish", {output_name}, test_mode);
}
......
......@@ -22,10 +22,10 @@ namespace tensorrt {
namespace plugin {
nvinfer1::Dims HardSwishPlugin::getOutputDimensions(
int index, const nvinfer1::Dims* in_dims, int nb_inputs) TRT_NOEXCEPT {
int index, const nvinfer1::Dims *in_dims, int nb_inputs) TRT_NOEXCEPT {
assert(nb_inputs == 1);
assert(index < this->getNbOutputs());
nvinfer1::Dims const& input_dims = in_dims[0];
nvinfer1::Dims const &input_dims = in_dims[0];
nvinfer1::Dims output_dims = input_dims;
return output_dims;
}
......@@ -42,7 +42,7 @@ __device__ T kMin(T a, T b) {
template <typename T, unsigned TPB>
__global__ void hard_swish_kernel(float threshold, float scale, float offset,
int n, const T* input, T* output) {
int n, const T *input, T *output) {
const int idx = blockIdx.x * TPB + threadIdx.x;
if (idx < n) {
const T in = input[idx];
......@@ -50,14 +50,14 @@ __global__ void hard_swish_kernel(float threshold, float scale, float offset,
}
}
int HardSwishPlugin::enqueue(int batch_size, const void* const* inputs,
int HardSwishPlugin::enqueue(int batch_size, const void *const *inputs,
#if IS_TRT_VERSION_LT(8000)
void** outputs, void*, cudaStream_t stream) {
void **outputs, void *, cudaStream_t stream) {
#else
void* const* outputs, void*,
void *const *outputs, void *,
cudaStream_t stream) TRT_NOEXCEPT {
#endif
const auto& input_dims = this->getInputDims(0);
const auto &input_dims = this->getInputDims(0);
int num = batch_size;
for (int i = 0; i < input_dims.nbDims; i++) {
num *= input_dims.d[i];
......@@ -69,14 +69,79 @@ int HardSwishPlugin::enqueue(int batch_size, const void* const* inputs,
const int block_size = 256;
const int grid_size = (num + block_size - 1) / block_size;
const float* input = static_cast<const float*>(inputs[0]);
float* output = static_cast<float*>(outputs[0]);
const float *input = static_cast<const float *>(inputs[0]);
float *output = static_cast<float *>(outputs[0]);
hard_swish_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(
threshold, scale, offset, num, input, output);
return cudaGetLastError() != cudaSuccess;
}
#if IS_TRT_VERSION_GE(6000)
nvinfer1::DimsExprs HardSwishPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
return inputs[0];
}
int HardSwishPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT {
auto input_dims = input_desc[0].dims;
int num = 1;
for (int i = 0; i < input_dims.nbDims; i++) {
num *= input_dims.d[i];
}
float threshold = threshold_;
float scale = scale_;
float offset = offset_;
const int block_size = 256;
const int grid_size = (num + block_size - 1) / block_size;
const float *input = static_cast<const float *>(inputs[0]);
float *output = static_cast<float *>(outputs[0]);
hard_swish_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(
threshold, scale, offset, num, input, output);
return cudaGetLastError() != cudaSuccess;
}
nvinfer1::DataType HardSwishPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *input_types,
int nb_inputs) const TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(index, 0,
platform::errors::InvalidArgument(
"The Elementwise Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[0];
}
bool HardSwishPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument(
"The input of swish plugin shoule not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
(in_out && pos < (nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
......@@ -94,6 +94,113 @@ class HardSwishPluginCreator : public TensorRTPluginCreator {
};
REGISTER_TRT_PLUGIN_V2(HardSwishPluginCreator);
#if IS_TRT_VERSION_GE(6000)
class HardSwishPluginDynamic : public DynamicPluginTensorRT {
public:
HardSwishPluginDynamic(const float threshold, const float scale,
const float offset)
: threshold_(threshold), scale_(scale), offset_(offset) {}
// It was used for tensorrt deserialization.
// It should not be called by users.
HardSwishPluginDynamic(void const* serialData, size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &threshold_);
DeserializeValue(&serialData, &serialLength, &scale_);
DeserializeValue(&serialData, &serialLength, &offset_);
}
~HardSwishPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new HardSwishPluginDynamic(threshold_, scale_, offset_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "hard_swish_plugin_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
int initialize() TRT_NOEXCEPT override { return 0; }
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(threshold_) + SerializedSize(scale_) +
SerializedSize(offset_);
}
// TRT will call this func to serialize the configuration of TRT
// It should not be called by users.
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, threshold_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, offset_);
}
nvinfer1::DataType getOutputDataType(
int index, const nvinfer1::DataType* inputTypes,
int nbInputs) const TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT override {}
void destroy() TRT_NOEXCEPT override { delete this; }
protected:
float threshold_;
float scale_;
float offset_;
};
class HardSwishPluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
HardSwishPluginDynamicCreator() {}
const char* getPluginName() const TRT_NOEXCEPT override {
return "hardswish_plugin_dynamic";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(const char* name,
const nvinfer1::PluginFieldCollection* fc)
TRT_NOEXCEPT override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(
const char* name, const void* serial_data,
size_t serial_length) TRT_NOEXCEPT override {
auto plugin = new HardSwishPluginDynamic(serial_data, serial_length);
return plugin;
}
void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const TRT_NOEXCEPT override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(HardSwishPluginDynamicCreator);
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
class TrtConvertHardSwishTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
inputs = program_config.inputs
weights = program_config.weights
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
if attrs[0]['threshold'] <= 0 or attrs[0]['scale'] <= 0:
return False
return True
def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]]):
return np.ones([1, 3, 64, 64]).astype(np.float32)
for threshold in [6.0, 7.0, 100.0, 0.0, -1.0]:
for scale in [5.0, 6.0, 7.0, -1.0, 0.0, 100.0]:
for offset in [3.0, 4.0, 5.0, -1.0, 0.0, 100.0]:
dics = [{
"threshold": threshold,
"scale": scale,
"offset": offset
}]
ops_config = [{
"op_type": "hard_swish",
"op_inputs": {
"X": ["input_data"]
},
"op_outputs": {
"Out": ["hard_swish_output_data"]
},
"op_attrs": dics[0]
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(data_gen=partial(
generate_input1, dics))
},
outputs=["hard_swish_output_data"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]}
self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]}
self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 2
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), (1e-5, 1e-5)
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), (1e-5, 1e-5)
def test(self):
self.run_test()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册