未验证 提交 a1abb7c9 编写于 作者: W wenbin 提交者: GitHub

swish refactor (#42610)

* swish refactor

* bug fix

* trt7 non-linear bug fix
上级 29a6b8c9
......@@ -75,7 +75,7 @@ class SwishOpConverter : public OpConverter {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::SwishPlugin* plugin = new plugin::SwishPlugin(beta, with_fp16);
layer = engine_->AddPlugin(&input, input_num, plugin);
layer = engine_->AddPluginV2Ext(&input, input_num, plugin);
}
auto output_name = op_desc.Output("Out")[0];
......
......@@ -24,6 +24,16 @@ namespace tensorrt {
namespace plugin {
// No per-instance resources to set up; TensorRT expects 0 on success.
int SwishPlugin::initialize() TRT_NOEXCEPT { return 0; }
// Nothing to release, since initialize() allocates nothing.
void SwishPlugin::terminate() TRT_NOEXCEPT {}
// Reports which tensor data types this (static-shape) plugin accepts:
// fp32 always; fp16 additionally when the plugin was built with fp16.
// NOTE(review): `format` is ignored here — presumably only kLINEAR is
// offered for this plugin version; confirm against the engine builder.
bool SwishPlugin::supportsFormat(
    nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT {
  const bool is_fp32 = (type == nvinfer1::DataType::kFLOAT);
  const bool is_fp16 = (type == nvinfer1::DataType::kHALF);
  return is_fp32 || (with_fp16_ && is_fp16);
}
nvinfer1::Dims SwishPlugin::getOutputDimensions(int index,
const nvinfer1::Dims *inputDims,
......@@ -85,17 +95,29 @@ int SwishPlugin::enqueue(int batch_size, const void *const *inputs,
void *const *outputs, void *workspace,
cudaStream_t stream) TRT_NOEXCEPT {
#endif
// input dims is CHW.
const auto &input_dims = this->getInputDims(0);
const float *input = reinterpret_cast<const float *>(inputs[0]);
float *output = reinterpret_cast<float *const *>(outputs)[0];
int num = batch_size;
for (int i = 0; i < input_dims.nbDims; i++) {
num *= input_dims.d[i];
}
int threads = 1024;
int blocks = (num + threads - 1) / threads;
swish_kernel<<<blocks, threads, 0, stream>>>(num, input, output, beta_);
auto type = getDataType();
if (type == nvinfer1::DataType::kFLOAT) {
VLOG(1) << "TRT Plugin DataType selected. Swish-->fp32";
const float *input = reinterpret_cast<const float *>(inputs[0]);
float *output = reinterpret_cast<float *const *>(outputs)[0];
swish_kernel<<<blocks, threads, 0, stream>>>(num, input, output, beta_);
} else if (type == nvinfer1::DataType::kHALF) {
VLOG(1) << "TRT Plugin DataType selected. Swish-->fp16";
const half *input = reinterpret_cast<const half *>(inputs[0]);
half *output = reinterpret_cast<half *const *>(outputs)[0];
swish_kernel<<<blocks, threads, 0, stream>>>(num, input, output,
(half)beta_);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The Swish TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
......@@ -140,12 +162,15 @@ bool SwishPluginDynamic::supportsFormatCombination(
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
bool res = (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF);
// encounter trt crash bug
#if IS_TRT_VERSION_LT(8000)
res = res && (in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
return res;
} else {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
return in.type == nvinfer1::DataType::kFLOAT;
}
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
......
......@@ -26,7 +26,7 @@ namespace inference {
namespace tensorrt {
namespace plugin {
class SwishPlugin : public PluginTensorRT {
class SwishPlugin : public PluginTensorRTV2Ext {
private:
float beta_;
......@@ -55,13 +55,24 @@ class SwishPlugin : public PluginTensorRT {
int initialize() TRT_NOEXCEPT override;
SwishPlugin* clone() const TRT_NOEXCEPT override {
return new SwishPlugin(beta_, with_fp16_);
// Deep-copies the plugin, carrying over the already-configured tensor
// metadata so the clone is usable without reconfiguration.
nvinfer1::IPluginV2Ext* clone() const TRT_NOEXCEPT override {
  auto* copy = new SwishPlugin(beta_, with_fp16_);
  copy->input_dims_ = input_dims_;
  copy->data_type_ = data_type_;
  copy->data_format_ = data_format_;
  return copy;
}
// Type name under which this plugin is registered; the creator's
// deserializePlugin is looked up by this name plus the version string.
const char* getPluginType() const TRT_NOEXCEPT override {
return "swish_plugin";
}
// The single output's dtype mirrors the first input's dtype
// (fp32 or fp16, as admitted by supportsFormat()).
nvinfer1::DataType getOutputDataType(
int index, const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT override {
return input_types[0];
}
// Swish is an element-wise activation: exactly one output tensor.
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
// Output dimensions; implemented in the .cu translation unit.
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
int nbInputDims) TRT_NOEXCEPT override;
......@@ -71,6 +82,12 @@ class SwishPlugin : public PluginTensorRT {
int enqueue(int batchSize, const void* const* inputs, void* const* outputs,
#endif
void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
void terminate() TRT_NOEXCEPT override;
// TensorRT owns the plugin's lifetime and releases it via destroy().
void destroy() TRT_NOEXCEPT override { delete this; }
// Version "2" marks the IPluginV2Ext-based rewrite (previous version was "1").
const char* getPluginVersion() const TRT_NOEXCEPT override { return "2"; }
bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format)
const TRT_NOEXCEPT override;
};
class SwishPluginCreator : public TensorRTPluginCreator {
......@@ -79,7 +96,7 @@ class SwishPluginCreator : public TensorRTPluginCreator {
return "swish_plugin";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
// Must match SwishPlugin::getPluginVersion() ("2") so registry lookup
// during deserialization resolves to this creator.
const char* getPluginVersion() const TRT_NOEXCEPT override { return "2"; }
nvinfer1::IPluginV2* deserializePlugin(
const char* name, const void* serial_data,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册