Unverified commit a1abb7c9, authored by wenbin, committed by GitHub

swish refactor (#42610)

* swish refactor

* bug fix

* trt7 non-linear bug fix
Parent 29a6b8c9
@@ -75,7 +75,7 @@ class SwishOpConverter : public OpConverter {
     bool with_fp16 =
         engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
     plugin::SwishPlugin* plugin = new plugin::SwishPlugin(beta, with_fp16);
-    layer = engine_->AddPlugin(&input, input_num, plugin);
+    layer = engine_->AddPluginV2Ext(&input, input_num, plugin);
   }
   auto output_name = op_desc.Output("Out")[0];
...
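The converter now hands the plugin to the engine through AddPluginV2Ext instead of AddPlugin, matching the new PluginTensorRTV2Ext base class. The helper itself is outside this diff; a minimal sketch of what it plausibly does, assuming the engine tracks owned plugins in a member list and exposes the TensorRT network via network() (both names are assumptions here):

```cpp
// Sketch only: member and type names are assumed, not taken from this diff.
nvinfer1::IPluginV2Layer* TensorRTEngine::AddPluginV2Ext(
    nvinfer1::ITensor* const* inputs, int num_inputs,
    plugin::PluginTensorRTV2Ext* plugin) {
  owned_plugin_v2ext_.emplace_back(plugin);  // engine takes ownership
  // INetworkDefinition::addPluginV2 accepts any IPluginV2-derived plugin.
  return network()->addPluginV2(inputs, num_inputs, *plugin);
}
```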
@@ -24,6 +24,16 @@ namespace tensorrt {
 namespace plugin {
 
 int SwishPlugin::initialize() TRT_NOEXCEPT { return 0; }
+
+void SwishPlugin::terminate() TRT_NOEXCEPT {}
+
+bool SwishPlugin::supportsFormat(
+    nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT {
+  if (with_fp16_) {
+    return type == nvinfer1::DataType::kFLOAT ||
+           type == nvinfer1::DataType::kHALF;
+  }
+  return type == nvinfer1::DataType::kFLOAT;
+}
 
 nvinfer1::Dims SwishPlugin::getOutputDimensions(int index,
                                                 const nvinfer1::Dims *inputDims,
@@ -85,17 +95,29 @@ int SwishPlugin::enqueue(int batch_size, const void *const *inputs,
                          void *const *outputs, void *workspace,
                          cudaStream_t stream) TRT_NOEXCEPT {
 #endif
-  // input dims is CHW.
   const auto &input_dims = this->getInputDims(0);
-  const float *input = reinterpret_cast<const float *>(inputs[0]);
-  float *output = reinterpret_cast<float *const *>(outputs)[0];
   int num = batch_size;
   for (int i = 0; i < input_dims.nbDims; i++) {
     num *= input_dims.d[i];
   }
   int threads = 1024;
   int blocks = (num + threads - 1) / threads;
-  swish_kernel<<<blocks, threads, 0, stream>>>(num, input, output, beta_);
+  auto type = getDataType();
+  if (type == nvinfer1::DataType::kFLOAT) {
+    VLOG(1) << "TRT Plugin DataType selected. Swish-->fp32";
+    const float *input = reinterpret_cast<const float *>(inputs[0]);
+    float *output = reinterpret_cast<float *const *>(outputs)[0];
+    swish_kernel<<<blocks, threads, 0, stream>>>(num, input, output, beta_);
+  } else if (type == nvinfer1::DataType::kHALF) {
+    VLOG(1) << "TRT Plugin DataType selected. Swish-->fp16";
+    const half *input = reinterpret_cast<const half *>(inputs[0]);
+    half *output = reinterpret_cast<half *const *>(outputs)[0];
+    swish_kernel<<<blocks, threads, 0, stream>>>(num, input, output,
+                                                 (half)beta_);
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "The Swish TRT Plugin's input type should be float or half."));
+  }
   return cudaGetLastError() != cudaSuccess;
 }
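enqueue now dispatches on getDataType(): the fp32 and fp16 branches launch swish_kernel with matching pointer types, and beta_ is narrowed to half for the fp16 launch. The kernel definition is elsewhere in the file and not part of this diff; a minimal sketch of a templated kernel compatible with both launches, assuming swish(x) = x / (1 + exp(-beta * x)):

```cuda
#include <cuda_fp16.h>

// Sketch only: the real swish_kernel may use __ldg loads or half2 math.
// Arithmetic is done in float even for half tensors, for clarity.
template <typename T>
__global__ void swish_kernel(int num, const T *input, T *output, T beta) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < num) {
    float x = static_cast<float>(input[idx]);
    float b = static_cast<float>(beta);
    // swish(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x))
    output[idx] = static_cast<T>(x / (1.0f + expf(-b * x)));
  }
}
```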
@@ -140,12 +162,15 @@ bool SwishPluginDynamic::supportsFormatCombination(
   const nvinfer1::PluginTensorDesc &in = in_out[pos];
   if (pos == 0) {
     if (with_fp16_) {
-      return (in.type == nvinfer1::DataType::kFLOAT ||
-              in.type == nvinfer1::DataType::kHALF) &&
-             (in.format == nvinfer1::TensorFormat::kLINEAR);
+      bool res = (in.type == nvinfer1::DataType::kFLOAT ||
+                  in.type == nvinfer1::DataType::kHALF);
+      // TRT < 8.0 crashes on non-linear formats, so require kLINEAR there.
+#if IS_TRT_VERSION_LT(8000)
+      res = res && (in.format == nvinfer1::TensorFormat::kLINEAR);
+#endif
+      return res;
     } else {
-      return (in.type == nvinfer1::DataType::kFLOAT) &&
-             (in.format == nvinfer1::TensorFormat::kLINEAR);
+      return in.type == nvinfer1::DataType::kFLOAT;
     }
   }
   const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
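The new guard applies the kLINEAR restriction only when building against TensorRT older than 8.0, which is what the "trt7 non-linear bug fix" in the commit message refers to. IS_TRT_VERSION_LT comes from Paddle's TensorRT helpers, not from this diff; a sketch of how such a macro is typically assembled from TensorRT's version defines (the exact encoding is an assumption):

```cpp
#include <NvInferVersion.h>  // NV_TENSORRT_MAJOR / MINOR / PATCH

// Sketch only: assumed encoding such that IS_TRT_VERSION_LT(8000) reads as
// "linked against TensorRT < 8.0".
#define TRT_VERSION                                     \
  (NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \
   NV_TENSORRT_PATCH * 10)
#define IS_TRT_VERSION_LT(version) ((TRT_VERSION) < (version))
#define IS_TRT_VERSION_GE(version) ((TRT_VERSION) >= (version))
```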
...
@@ -26,7 +26,7 @@ namespace inference {
 namespace tensorrt {
 namespace plugin {
 
-class SwishPlugin : public PluginTensorRT {
+class SwishPlugin : public PluginTensorRTV2Ext {
  private:
   float beta_;
@@ -55,13 +55,24 @@ class SwishPlugin : public PluginTensorRT {
   int initialize() TRT_NOEXCEPT override;
 
-  SwishPlugin* clone() const TRT_NOEXCEPT override {
-    return new SwishPlugin(beta_, with_fp16_);
+  nvinfer1::IPluginV2Ext* clone() const TRT_NOEXCEPT override {
+    auto* plugin = new SwishPlugin(beta_, with_fp16_);
+    plugin->data_format_ = data_format_;
+    plugin->data_type_ = data_type_;
+    plugin->input_dims_ = input_dims_;
+    return plugin;
   }
 
   const char* getPluginType() const TRT_NOEXCEPT override {
     return "swish_plugin";
   }
+
+  nvinfer1::DataType getOutputDataType(
+      int index, const nvinfer1::DataType* input_types,
+      int nb_inputs) const TRT_NOEXCEPT override {
+    return input_types[0];
+  }
+
   int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
   nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
                                      int nbInputDims) TRT_NOEXCEPT override;
@@ -71,6 +82,12 @@ class SwishPlugin : public PluginTensorRT {
   int enqueue(int batchSize, const void* const* inputs, void* const* outputs,
 #endif
               void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
+
+  void terminate() TRT_NOEXCEPT override;
+  void destroy() TRT_NOEXCEPT override { delete this; }
+  const char* getPluginVersion() const TRT_NOEXCEPT override { return "2"; }
+  bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format)
+      const TRT_NOEXCEPT override;
 };
@@ -79,7 +96,7 @@ class SwishPluginCreator : public TensorRTPluginCreator {
     return "swish_plugin";
   }
 
-  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
+  const char* getPluginVersion() const TRT_NOEXCEPT override { return "2"; }
 
   nvinfer1::IPluginV2* deserializePlugin(
       const char* name, const void* serial_data,
...
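The creator's version string is bumped to "2" so it matches the plugin's new getPluginVersion; TensorRT looks creators up by plugin type and version when deserializing an engine. The deserializePlugin body is outside this diff; a plausible sketch, assuming SwishPlugin has a deserializing constructor that restores the fields written by serialize():

```cpp
// Sketch only: assumes a SwishPlugin(const void*, size_t) constructor.
nvinfer1::IPluginV2* SwishPluginCreator::deserializePlugin(
    const char* name, const void* serial_data,
    size_t serial_length) TRT_NOEXCEPT {
  return new SwishPlugin(serial_data, serial_length);
}
```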