Unverified Commit c528f1d4 authored by Pei Yang, committed by GitHub

[Paddle-TRT] Add hard_sigmoid and hard_swish support (supports MobileNetV3) (#23672)

* add hard_sigmoid trt op converter

* add hard_swish op converter and plugin. test=develop

* add macro to adapt lower trt version. test=develop
Parent 015acdbf
......@@ -961,6 +961,8 @@ USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout);
USE_TRT_CONVERTER(pad);
USE_TRT_CONVERTER(hard_sigmoid);
USE_TRT_CONVERTER(hard_swish);
USE_TRT_CONVERTER(split);
USE_TRT_CONVERTER(prelu);
USE_TRT_CONVERTER(conv2d_transpose);
......
......@@ -4,7 +4,7 @@ nv_library(tensorrt_converter
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc hard_sigmoid_op.cc hard_swish_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* HardSigmoidOp maps to IActivationLayer in TRT. This layer doesn't have weights.
*/
class HardSigmoidOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(5000)
VLOG(3) << "convert a fluid HardSigmoid op to tensorrt IActivationLayer "
"layer without bias";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
float slope = boost::get<float>(op_desc.GetAttr("slope"));
float offset = boost::get<float>(op_desc.GetAttr("offset"));
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Activation, *input,
nvinfer1::ActivationType::kHARD_SIGMOID);
layer->setAlpha(slope);
layer->setBeta(offset);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "hard_sigmoid", {output_name}, test_mode);
#else
PADDLE_THROW(platform::errors::Fatal(
"Hard sigmoid TRT converter is only supported on TRT 5 or higher. "
"Please confirm your TRT version is no less than 5.0."));
#endif
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(hard_sigmoid, HardSigmoidOpConverter);
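For reference, TRT's kHARD_SIGMOID activation computes y = max(0, min(1, alpha * x + beta)), so passing Paddle's slope as alpha and offset as beta reproduces fluid's hard_sigmoid exactly. A minimal host-side sketch of the same math (hard_sigmoid_ref is a hypothetical helper, not part of this patch):

// Reference-only sketch of what the converted IActivationLayer computes.
static float hard_sigmoid_ref(float x, float slope, float offset) {
  float y = slope * x + offset;                // TRT: alpha * x + beta
  return y < 0.f ? 0.f : (y > 1.f ? 1.f : y);  // clamp to [0, 1]
}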
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* HardSwish converter from fluid to TensorRT.
*/
class HardSwishOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid HardSwish op to tensorrt HardSwish plugin";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
int input_num = op_desc.Input("X").size();
PADDLE_ENFORCE_EQ(
    input_num, 1,
    platform::errors::InvalidArgument(
        "HardSwish op expects exactly 1 input, but received %d.", input_num));
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
// Get output
size_t output_num = op_desc.Output("Out").size();
PADDLE_ENFORCE_EQ(
    output_num, 1,
    platform::errors::InvalidArgument(
        "HardSwish op expects exactly 1 output, but received %d.", output_num));
const float threshold =
op_desc.HasAttr("threshold")
? boost::get<float>(op_desc.GetAttr("threshold"))
: 6.0f;
const float scale = op_desc.HasAttr("scale")
? boost::get<float>(op_desc.GetAttr("scale"))
: 6.0f;
const float offset = op_desc.HasAttr("offset")
? boost::get<float>(op_desc.GetAttr("offset"))
: 3.0f;
nvinfer1::ILayer* layer = nullptr;
plugin::HardSwishPlugin* plugin =
new plugin::HardSwishPlugin(threshold, scale, offset);
layer = engine_->AddPlugin(&input, input_num, plugin);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "hard_swish", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(hard_swish, HardSwishOpConverter);
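For context, fluid's hard_swish computes out = x * min(max(x + offset, 0), threshold) / scale, and the fallbacks above (threshold = 6, scale = 6, offset = 3) match the op's default attributes. A reference-only sketch of the same math (hard_swish_ref is a hypothetical helper, not part of this patch):

// Reference-only sketch of what the plugin computes on each element.
static float hard_swish_ref(float x, float threshold = 6.f, float scale = 6.f,
                            float offset = 3.f) {
  float clipped = x + offset;
  clipped = clipped < 0.f ? 0.f : (clipped > threshold ? threshold : clipped);
  return x * clipped / scale;  // same math as the CUDA kernel below
}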
......@@ -23,6 +23,7 @@ struct SimpleOpTypeSetTeller : public Teller {
SimpleOpTypeSetTeller() {
#if IS_TRT_VERSION_GE(5130)
teller_set.insert("relu6");
teller_set.insert("hard_sigmoid");
#endif
#if IS_TRT_VERSION_GE(6000)
teller_set.insert("fused_embedding_eltwise_layernorm");
......@@ -54,6 +55,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"relu",
"softmax",
"sigmoid",
"hard_swish",
"depthwise_conv2d",
"batch_norm",
"concat",
......
......@@ -3,5 +3,5 @@ nv_library(tensorrt_plugin
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu hard_swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include "paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
HardSwishPlugin* CreateHardSwishPluginDeserialize(const void* buffer,
size_t length) {
return new HardSwishPlugin(buffer, length);
}
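// Ties the serialized type name ("hard_swish_plugin", matching
// getPluginType()) to the deserializer so saved engines can recreate
// this plugin on load.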
REGISTER_TRT_PLUGIN("hard_swish_plugin", CreateHardSwishPluginDeserialize);
nvinfer1::Dims HardSwishPlugin::getOutputDimensions(
int index, const nvinfer1::Dims* in_dims, int nb_inputs) {
assert(nb_inputs == 1);
assert(index < this->getNbOutputs());
nvinfer1::Dims const& input_dims = in_dims[0];
nvinfer1::Dims output_dims = input_dims;
return output_dims;
}
template <typename T>
__device__ T kMax(T a, T b) {
return a > b ? a : b;
}
template <typename T>
__device__ T kMin(T a, T b) {
return a < b ? a : b;
}
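// Elementwise hard_swish: out = x / scale * min(max(x + offset, 0), threshold),
// computed with one thread per element.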
template <typename T, unsigned TPB>
__global__ void hard_swish_kernel(float threshold, float scale, float offset,
int n, const T* input, T* output) {
const int idx = blockIdx.x * TPB + threadIdx.x;
if (idx < n) {
const T in = input[idx];
output[idx] = in / scale * kMin<T>(kMax<T>(in + offset, 0), threshold);
}
}
int HardSwishPlugin::enqueue(int batch_size, const void* const* inputs,
void** outputs, void*, cudaStream_t stream) {
const auto& input_dims = this->getInputDims(0);
int num = batch_size;
for (int i = 0; i < input_dims.nbDims; i++) {
num *= input_dims.d[i];
}
float threshold = threshold_;
float scale = scale_;
float offset = offset_;
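// Round the grid up so every element gets a thread; the `idx < n` guard in
// the kernel masks the excess threads of the last block.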
const int block_size = 256;
const int grid_size = (num + block_size - 1) / block_size;
const float* input = static_cast<const float*>(inputs[0]);
float* output = static_cast<float*>(outputs[0]);
hard_swish_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(
threshold, scale, offset, num, input, output);
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class HardSwishPlugin : public PluginTensorRT {
public:
HardSwishPlugin(const float threshold, const float scale, const float offset)
: threshold_(threshold), scale_(scale), offset_(offset) {}
// Used for TensorRT deserialization.
// It should not be called by users.
HardSwishPlugin(void const* serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &threshold_);
DeserializeValue(&serialData, &serialLength, &scale_);
DeserializeValue(&serialData, &serialLength, &offset_);
}
~HardSwishPlugin() {}
HardSwishPlugin* clone() const override {
return new HardSwishPlugin(threshold_, scale_, offset_);
}
const char* getPluginType() const override { return "hard_swish_plugin"; }
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
int nbInputDims) override;
int enqueue(int batchSize, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream) override;
protected:
float threshold_;
float scale_;
float offset_;
size_t getSerializationSize() override {
return getBaseSerializationSize() + SerializedSize(threshold_) +
SerializedSize(scale_) + SerializedSize(offset_) +
SerializedSize(getPluginType());
}
// TRT calls this function to serialize the plugin's configuration.
// It should not be called by users.
void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
SerializeValue(&buffer, threshold_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, offset_);
}
};
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
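For illustration, a plugin built with the converter's fallback attributes acts as a single-output layer whose output is shaped like its input; a minimal sketch (standalone construction outside an engine is for exposition only, values mirror the defaults above):

// Sketch: construct the plugin with the default hard_swish attributes.
plugin::HardSwishPlugin plugin(/*threshold=*/6.0f, /*scale=*/6.0f,
                               /*offset=*/3.0f);
assert(plugin.getNbOutputs() == 1);  // one output, shaped like the input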