Unverified commit 1a6ce8b9, authored by: Zhaolong Xing, committed by: GitHub

add swish split gelu plugin dynamic support (#23305)

test=develop
Parent 2bb1b0e8
@@ -19,6 +19,9 @@ namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * Gelu converter from fluid to tensorRT.
 */
@@ -40,15 +43,21 @@ class GeluOpConverter : public OpConverter {
    PADDLE_ENFORCE_EQ(output_num, 1,
                      platform::errors::InvalidArgument(
                          "gelu op has only 1 output, but got %d", output_num));
    nvinfer1::ILayer* layer = nullptr;
    if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
      plugin::GeluPluginDynamic* plugin = new plugin::GeluPluginDynamic();
layer = engine_->AddPluginV2(&input, input_num, plugin);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
} else {
plugin::GeluPlugin* plugin = new plugin::GeluPlugin();
layer = engine_->AddPlugin(&input, input_num, plugin);
    }
    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(layer, "gelu", {output_name}, test_mode);
  }
......
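A note before the next two files: the split and swish converters below follow the same dispatch pattern as the gelu converter above. When the engine is built for dynamic shapes, the converter creates the IPluginV2DynamicExt-based *PluginDynamic class and registers it through engine_->AddPluginV2, guarded by IS_TRT_VERSION_GE(6000) because the dynamic-shape plugin interface only exists from TensorRT 6.0; otherwise it falls back to the original static plugin via engine_->AddPlugin.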
@@ -37,25 +37,57 @@ class SplitOpConverter : public OpConverter {
    int axis = boost::get<int>(op_desc.GetAttr("axis"));
    // split on batch is not supported in TensorRT
    PADDLE_ENFORCE(axis != 0);
    std::vector<int> output_lengths =
        boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
    int num = 0;
    if (op_desc.HasAttr("num")) {
num = boost::get<int>(op_desc.GetAttr("num"));
}
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
axis += (axis < 0) ? input_dims.nbDims : 0;
#endif
} else {
axis += (axis < 0) ? input_dims.nbDims : -1;
}
    PADDLE_ENFORCE_NE(input_dims.d[axis], -1,
                      platform::errors::InvalidArgument(
                          "Dim (%d) of the input should not be -1", axis));
    if (num > 0) {
      int64_t in_axis_dim = input_dims.d[axis];
      PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
                        "Tensor split does not result"
                        " in an equal division");
      size_t out_axis_dim = in_axis_dim / num;
      for (int i = 0; i < num; ++i) {
        output_lengths.push_back(out_axis_dim);
      }
    }

    PADDLE_ENFORCE_EQ(
        output_lengths.size(), output_num,
        platform::errors::InvalidArgument(
"The output_length should be equal to the output size."));
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
plugin::SplitPluginDynamic* plugin =
new plugin::SplitPluginDynamic(axis, output_lengths);
layer = engine_->AddPluginV2(&input, input_num, plugin);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
} else {
plugin::SplitPlugin* plugin =
new plugin::SplitPlugin(axis, output_lengths);
layer = engine_->AddPlugin(&input, input_num, plugin);
}
    std::string layer_name = "split (Output: ";
    for (size_t i = 0; i < output_num; i++) {
......
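Why the axis arithmetic above differs between the two branches: in static-shape mode TensorRT dimensions exclude the batch, while in dynamic-shape mode they include it. A small worked example (the shapes are hypothetical):

// For an NCHW input and the fluid attribute axis = 1 (the channel dim):
//   - static-shape mode: TRT sees [C, H, W] (no batch), so the plugin
//     axis becomes 1 - 1 = 0;
//   - dynamic-shape mode: TRT sees [N, C, H, W] (batch kept), so the
//     axis stays 1.
// A negative axis is first wrapped by adding input_dims.nbDims.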
@@ -36,10 +36,20 @@ class SwishOpConverter : public OpConverter {
    // Get attrs
    float beta = boost::get<float>(op_desc.GetAttr("beta"));
    nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
      plugin::SwishPluginDynamic* plugin = new plugin::SwishPluginDynamic(beta);
layer = engine_->AddPluginV2(&input, input_num, plugin);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
} else {
plugin::SwishPlugin* plugin = new plugin::SwishPlugin(beta);
layer = engine_->AddPlugin(&input, input_num, plugin);
}
    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(layer, "swish", {output_name}, test_mode);
......
@@ -148,6 +148,7 @@ void TensorRTEngine::FreezeNetwork() {
  if (with_dynamic_shape_) {
#if IS_TRT_VERSION_GE(6000)
LOG(INFO) << "Run Paddle-TRT Dynamic Shape mode.";
    for (auto &input : min_input_shape_) {
      optim_profile_->setDimensions(
          input.first.c_str(), nvinfer1::OptProfileSelector::kMIN,
......
@@ -24,12 +24,29 @@ namespace tensorrt {
namespace plugin {

// constants for approximating the normal cdf
static const float kA = 1.41421356237309504;       // sqrt(2)
static const float kAT = 0.5;
static const float kBT = 0.7978845608028654; // sqrt(2.0/M_PI)
static const float kCT = 0.035677408136300125; // 0.044715 * sqrt(2.0/M_PI)
GeluPlugin* CreateGeluPluginDeserialize(const void* buffer, size_t length) {
  return new GeluPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("gelu_plugin", CreateGeluPluginDeserialize);
bool GeluPlugin::supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const {
#ifdef SUPPORTS_CUDA_FP16
return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kNCHW));
#else
return ((type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::PluginFormat::kNCHW));
#endif
}
nvinfer1::Dims GeluPlugin::getOutputDimensions(int index,
                                               const nvinfer1::Dims* in_dims,
@@ -42,7 +59,7 @@ nvinfer1::Dims GeluPlugin::getOutputDimensions(int index,
}

template <typename T, unsigned TPB>
__global__ void gelu_kernel(const T a, int n, const T* input, T* output) {
  const int idx = blockIdx.x * TPB + threadIdx.x;
  if (idx < n) {
    const T in = input[idx];
@@ -51,24 +68,152 @@ __global__ void geluKernel(const T a, int n, const T* input, T* output) {
  }
}
template <typename T>
__device__ T do_tanh(T a);

template <>
__device__ float do_tanh<float>(float a) {
  return tanf(a);
}

template <>
__device__ half do_tanh<half>(half a) {
const float tmp = tanhf(__half2float(a));
return __float2half(tmp);
}
// the kernel below is not aligned with the fluid fp32 forward one; use it
// for fp16.
template <typename T, unsigned TPB>
__global__ void no_exact_gelu_kernel(const T a, const T b, const T c, int n,
const T* input, T* output) {
const int idx = blockIdx.x * TPB + threadIdx.x;
if (idx < n) {
const T in = input[idx];
const T tmp = in * (c * in * in + b);
const T cdf = a + a * do_tanh<T>(tmp);
output[idx] = in * cdf;
}
}
int GeluPlugin::enqueue(int batch_size, const void* const* inputs,
                        void** outputs, void*, cudaStream_t stream) {
  const auto& input_dims = this->getInputDims(0);
  int num = batch_size;
  for (int i = 0; i < input_dims.nbDims; i++) {
    num *= input_dims.d[i];
  }
const int block_size = 256;
const int grid_size = (num + block_size - 1) / block_size;
auto type = getDataType();
if (type == nvinfer1::DataType::kFLOAT) {
const float* input = static_cast<const float*>(inputs[0]);
float* output = static_cast<float*>(outputs[0]);
gelu_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(
kA, num, input, output);
} else if (type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
const half* input = static_cast<const half*>(inputs[0]);
half* output = static_cast<half*>(outputs[0]);
no_exact_gelu_kernel<half,
block_size><<<grid_size, block_size, 0, stream>>>(
kAT, kBT, kCT, num, input, output);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "The CUDA arch you specified should be greater than 600."));
#endif
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The Gelu TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
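For reference, the math the two kernels implement. The fp32 path presumably applies the exact erf-based GELU (the elided gelu_kernel body uses kA = sqrt(2) as the erf argument scale), while no_exact_gelu_kernel is the usual tanh approximation, with a = kAT = 1/2, b = kBT = sqrt(2/pi), and c = kCT = 0.044715*sqrt(2/pi):

\mathrm{GELU}(x) = x\,\Phi(x) = \tfrac{x}{2}\left(1 + \operatorname{erf}\!\left(\tfrac{x}{\sqrt{2}}\right)\right)
\approx x\left(a + a\tanh\!\big(x\,(c\,x^{2} + b)\big)\right) = \tfrac{x}{2}\left(1 + \tanh\!\Big(\sqrt{\tfrac{2}{\pi}}\,\big(x + 0.044715\,x^{3}\big)\Big)\right)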
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
size_t GeluPluginDynamic::getSerializationSize() const { return 0; }
void GeluPluginDynamic::serialize(void* buffer) const {}
nvinfer1::DimsExprs GeluPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) {
return inputs[0];
}
bool GeluPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
int nb_outputs) {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of gelu plugin should not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc& in = in_out[pos];
if (pos == 0) {
#ifdef SUPPORTS_CUDA_FP16
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
}
const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType GeluPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
"The Gelu Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[0];
}
int GeluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs, void* const* outputs,
void* workspace, cudaStream_t stream) {
auto input_dims = input_desc[0].dims;
size_t num = ProductDim(input_dims);
const int block_size = 256;
const int grid_size = (num + block_size - 1) / block_size;
auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) {
const float* input = static_cast<const float*>(inputs[0]);
float* output = static_cast<float*>(outputs[0]);
gelu_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(
kA, num, input, output);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
const half* input = static_cast<const half*>(inputs[0]);
half* output = static_cast<half*>(outputs[0]);
no_exact_gelu_kernel<half,
block_size><<<grid_size, block_size, 0, stream>>>(
kAT, kBT, kCT, num, input, output);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "The CUDA arch you specified should be greater than 600."));
#endif
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The Gelu TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
#endif
} // namespace plugin
} // namespace tensorrt
......
@@ -25,46 +25,90 @@ namespace tensorrt {
namespace plugin {

class GeluPlugin : public PluginTensorRT {
public:
GeluPlugin() {}
// It was used for tensorrt deserialization.
// It should not be called by users.
GeluPlugin(void const* serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
}
~GeluPlugin() {}
GeluPlugin* clone() const override { return new GeluPlugin(); }
const char* getPluginType() const override { return "gelu_plugin"; }
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }
bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const override;
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
int nbInputDims) override;
int enqueue(int batchSize, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream) override;
 protected:
  size_t getSerializationSize() override {
    return getBaseSerializationSize() + SerializedSize(getPluginType());
  }

  // TRT will call this func to serialize the configuration of TRT
  // It should not be called by users.
  void serialize(void* buffer) override {
    SerializeValue(&buffer, getPluginType());
    serializeBase(buffer);
  }
};
#if IS_TRT_VERSION_GE(6000)
class GeluPluginDynamic : public DynamicPluginTensorRT {
 public:
  GeluPluginDynamic() {}
  GeluPluginDynamic(void const* serialData, size_t serialLength) {}
  ~GeluPluginDynamic() {}
  nvinfer1::IPluginV2DynamicExt* clone() const override {
    return new GeluPluginDynamic();
  }

  const char* getPluginType() const override { return "gelu_plugin"; }
  int getNbOutputs() const override { return 1; }
  int initialize() override { return 0; }

  size_t getSerializationSize() const override;
  void serialize(void* buffer) const override;

  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
      nvinfer1::IExprBuilder& exprBuilder) override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
  void destroy() override { delete this; }
};
#endif
} // namespace plugin
} // namespace tensorrt
......
@@ -128,6 +128,144 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
  return cudaGetLastError() != cudaSuccess;
}
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
int SplitPluginDynamic::initialize() { return 0; }
size_t SplitPluginDynamic::getSerializationSize() const { return 0; }
void SplitPluginDynamic::serialize(void* buffer) const {}
nvinfer1::DimsExprs SplitPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) {
PADDLE_ENFORCE_EQ(nb_inputs, 1,
platform::errors::InvalidArgument(
"The Split plugin should be only one input."));
  PADDLE_ENFORCE_LT(output_index, output_length_.size(),
                    platform::errors::InvalidArgument(
                        "When GetOutputDimensions is called, the index(%d) "
                        "should not be greater than the num(%d) of the "
                        "outputs.",
                        output_index, output_length_.size()));
nvinfer1::DimsExprs output_dims = inputs[0];
output_dims.d[axis_] = expr_builder.constant(output_length_.at(output_index));
return output_dims;
}
bool SplitPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
int nb_outputs) {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of split plugin should not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc& in = in_out[pos];
if (pos == 0) {
#ifdef SUPPORTS_CUDA_FP16
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
}
const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType SplitPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
return input_types[0];
}
int SplitPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs, void* const* outputs,
void* workspace, cudaStream_t stream) {
auto input_dims = input_desc[0].dims;
int outer_rows = 1;
int inner_cols = 1;
// with batch
for (int i = 0; i < axis_; i++) {
outer_rows *= input_dims.d[i];
}
for (int i = axis_ + 1; i < input_dims.nbDims; i++) {
inner_cols *= input_dims.d[i];
}
std::vector<int> segment_offsets(1, 0);
for (int i = 0; i < this->getNbOutputs(); i++) {
segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
}
int axis_shape = input_dims.d[axis_];
thrust::device_vector<int> d_segment_offsets = segment_offsets;
const int* d_segment_offsets_ptr =
thrust::raw_pointer_cast(&d_segment_offsets[0]);
dim3 block(32, 16);
dim3 grid(std::min((inner_cols - 1) / block.x + 1, 65535u),
std::min((axis_shape - 1) / block.y + 1, 65535u),
std::min((outer_rows - 1) / block.z + 1, 65535u));
auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) {
thrust::device_vector<float*> d_output_ptrs;
d_output_ptrs.resize(this->getNbOutputs(), nullptr);
const float* input_ptr = static_cast<const float*>(inputs[0]);
float* const* h_odatas = reinterpret_cast<float* const*>(outputs);
float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]);
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemcpyAsync(output_ptrs, h_odatas,
d_output_ptrs.size() * sizeof(float*),
cudaMemcpyHostToDevice, stream),
platform::errors::External(
"CUDA Memcpy failed during split plugin run."));
split_kernel<<<grid, block, 0, stream>>>(
d_segment_offsets.size(), d_segment_offsets_ptr, input_ptr, output_ptrs,
inner_cols, axis_shape, outer_rows);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
thrust::device_vector<half*> d_output_ptrs;
d_output_ptrs.resize(this->getNbOutputs(), nullptr);
const half* input_ptr = static_cast<const half*>(inputs[0]);
half* const* h_odatas = reinterpret_cast<half* const*>(outputs);
half** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]);
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemcpyAsync(output_ptrs, h_odatas,
d_output_ptrs.size() * sizeof(half*),
cudaMemcpyHostToDevice, stream),
platform::errors::External(
"CUDA Memcpy failed during split plugin run."));
split_kernel<<<grid, block, 0, stream>>>(
d_segment_offsets.size(), d_segment_offsets_ptr, input_ptr, output_ptrs,
inner_cols, axis_shape, outer_rows);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "The CUDA arch you specified should be greater than 600."));
#endif
}
return cudaGetLastError() != cudaSuccess;
}
#endif
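To make the indexing above concrete, here is a small worked example (the shapes are hypothetical) of how SplitPluginDynamic::enqueue decomposes the input:

// Splitting an input of dims [2, 6, 4] on axis_ = 1 with
// output_length_ = {2, 4}:
//   outer_rows      = 2            // product of dims before the axis
//   axis_shape      = 6            // the split dimension itself
//   inner_cols      = 4            // product of dims after the axis
//   segment_offsets = {0, 2, 6}    // running sum of output_length_
// Element (r, a, c) of the input is routed to the output k for which
// segment_offsets[k] <= a < segment_offsets[k + 1], at axis position
// a - segment_offsets[k]; the outputs get dims [2, 2, 4] and [2, 4, 4].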
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
@@ -27,28 +27,28 @@ namespace plugin {
class SplitPlugin : public PluginTensorRT {
 public:
  SplitPlugin() {}
  SplitPlugin(int axis, std::vector<int> const& output_lengths)
      : axis_(axis), same_shape_(true), output_length_(output_lengths) {}
  SplitPlugin(void const* serial_data, size_t serial_length) {
    deserializeBase(serial_data, serial_length);
    DeserializeValue(&serial_data, &serial_length, &axis_);
    DeserializeValue(&serial_data, &serial_length, &output_length_);
  }

  SplitPlugin* clone() const override {
    return new SplitPlugin(axis_, output_length_);
  }

  const char* getPluginType() const override { return "split_plugin"; }
  int getNbOutputs() const override { return output_length_.size(); }
  nvinfer1::Dims getOutputDimensions(int index,
                                     const nvinfer1::Dims* input_dims,
                                     int num_inputs) override;
  int initialize() override;
  int enqueue(int batchSize, const void* const* inputs, void** outputs,
              void* workspace, cudaStream_t stream) override;

 protected:
  size_t getSerializationSize() override {
@@ -56,7 +56,7 @@ class SplitPlugin : public PluginTensorRT {
           SerializedSize(output_length_) + getBaseSerializationSize();
  }

  void serialize(void* buffer) override {
    SerializeValue(&buffer, getPluginType());
    serializeBase(buffer);
    SerializeValue(&buffer, axis_);
@@ -71,9 +71,64 @@ class SplitPlugin : public PluginTensorRT {
  std::vector<int> output_length_;
  std::vector<int> segment_offsets_;
  thrust::device_vector<int> d_segment_offsets_;
  thrust::device_vector<float*> d_output_ptrs_;
};
#if IS_TRT_VERSION_GE(6000)
class SplitPluginDynamic : public DynamicPluginTensorRT {
public:
SplitPluginDynamic(int axis, std::vector<int> const& output_lengths)
: axis_(axis), output_length_(output_lengths) {}
SplitPluginDynamic(void const* serial_data, size_t serial_length) {}
nvinfer1::IPluginV2DynamicExt* clone() const override {
return new SplitPluginDynamic(axis_, output_length_);
}
const char* getPluginType() const override { return "split_plugin"; }
int getNbOutputs() const override { return output_length_.size(); }
int initialize() override;
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
void destroy() override { delete this; }
private:
int axis_;
std::vector<int> output_length_;
};
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
@@ -40,15 +40,33 @@ nvinfer1::Dims SwishPlugin::getOutputDimensions(int index,
  nvinfer1::Dims output_dims = input_dims;
  return output_dims;
}

template <typename T>
__device__ T math_exp(T a);
#ifdef SUPPORTS_CUDA_FP16
template <>
__device__ half math_exp<half>(half a) {
return hexp(a);
}
#endif
template <>
__device__ float math_exp<float>(float a) {
return expf(a);
}
template <typename T>
__global__ void swish_kernel(int num, const T *input, T *output, T beta) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < num) {
#if __CUDA_ARCH__ >= 350
    output[index] =
        __ldg(input + index) /
        (static_cast<T>(1.0) + math_exp<T>(-beta * __ldg(input + index)));
#else
    output[index] = input[index] /
                    (static_cast<T>(1.0) + math_exp<T>(-beta * input[index]));
#endif
  }
}
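The templated kernel computes the swish activation for both float and half, with math_exp dispatching to expf or hexp:

\mathrm{swish}_{\beta}(x) = \frac{x}{1 + e^{-\beta x}} = x \cdot \sigma(\beta x)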
@@ -70,6 +88,97 @@ int SwishPlugin::enqueue(int batch_size, const void *const *inputs,
  return cudaGetLastError() != cudaSuccess;
}
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
int SwishPluginDynamic::initialize() {
setPluginNamespace("swish");
return 0;
}
size_t SwishPluginDynamic::getSerializationSize() const { return 0; }
void SwishPluginDynamic::serialize(void *buffer) const {}
nvinfer1::DimsExprs SwishPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) {
return inputs[0];
}
bool SwishPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
int nb_outputs) {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of swish plugin should not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
#ifdef SUPPORTS_CUDA_FP16
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType SwishPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
"The Swish Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[0];
}
int SwishPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc,
const void *const *inputs, void *const *outputs,
void *workspace, cudaStream_t stream) {
auto input_dims = input_desc[0].dims;
size_t num = ProductDim(input_dims);
int threads = 1024;
int blocks = (num + threads - 1) / threads;
auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) {
const float *input = static_cast<const float *>(inputs[0]);
float *output = static_cast<float *>(outputs[0]);
swish_kernel<float><<<blocks, threads, 0, stream>>>(num, input, output,
beta_);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
const half *input = static_cast<const half *>(inputs[0]);
half *output = static_cast<half *>(outputs[0]);
swish_kernel<half><<<blocks, threads, 0, stream>>>(
num, input, output, static_cast<half>(beta_));
#else
    PADDLE_THROW(platform::errors::Fatal(
        "The CUDA arch you specified should be greater than 600."));
#endif
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The Swish TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
@@ -38,7 +38,7 @@ class SwishPlugin : public PluginTensorRT {
  // TRT will call this func when we need to serialize the configuration of
  // tensorrt.
  // It should not be called by users.
  void serialize(void* buffer) override {
    SerializeValue(&buffer, getPluginType());
    serializeBase(buffer);
    SerializeValue(&buffer, beta_);
@@ -49,23 +49,74 @@ class SwishPlugin : public PluginTensorRT {
  // It was used for tensorrt deserialization.
  // It should not be called by users.
  SwishPlugin(void const* serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    DeserializeValue(&serialData, &serialLength, &beta_);
  }
  ~SwishPlugin() {}
  int initialize() override;

  SwishPlugin* clone() const override { return new SwishPlugin(beta_); }

  const char* getPluginType() const override { return "swish_plugin"; }
  int getNbOutputs() const override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
                                     int nbInputDims) override;
  int enqueue(int batchSize, const void* const* inputs, void** outputs,
              void* workspace, cudaStream_t stream) override;
};
#if IS_TRT_VERSION_GE(6000)
class SwishPluginDynamic : public DynamicPluginTensorRT {
public:
explicit SwishPluginDynamic(const float beta) : beta_(beta) {}
SwishPluginDynamic(void const* serialData, size_t serialLength) {}
nvinfer1::IPluginV2DynamicExt* clone() const override {
return new SwishPluginDynamic(beta_);
}
const char* getPluginType() const override { return "swish_plugin"; }
int getNbOutputs() const override { return 1; }
int initialize() override;
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
void destroy() override { delete this; }
private:
float beta_;
};
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
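One caveat worth noting across the three plugin headers: the new dynamic plugins serialize nothing. getSerializationSize() returns 0, serialize() is empty, and the deserialization constructors ignore their arguments, so state such as beta_, axis_, and output_length_ would not survive an engine serialize/deserialize round trip. A fuller version might reuse the SerializeValue/DeserializeValue helpers the static plugins already use; a minimal sketch for the swish case (not part of this commit):

// Hedged sketch, not in this commit: persist beta_ the same way the
// static SwishPlugin already does.
size_t SwishPluginDynamic::getSerializationSize() const {
  return SerializedSize(beta_);
}

void SwishPluginDynamic::serialize(void* buffer) const {
  SerializeValue(&buffer, beta_);
}

SwishPluginDynamic::SwishPluginDynamic(void const* serial_data,
                                       size_t serial_length) {
  DeserializeValue(&serial_data, &serial_length, &beta_);
}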
@@ -373,9 +373,9 @@ if(WITH_GPU AND TENSORRT_FOUND)
          EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
          ARGS --infer_model=${TRT_MODEL_QUANT_RESNET_DIR})
  set(TEST_TRT_DYNAMIC_MODEL "${TRT_MODEL_INSTALL_DIR}/conv_bn_swish_split_gelu")
  if (NOT EXISTS ${TEST_TRT_DYNAMIC_MODEL})
    inference_download_and_uncompress(${TEST_TRT_DYNAMIC_MODEL} ${INFERENCE_URL}/tensorrt_test "conv_bn_swish_split_gelu.tar.gz")
  endif()
  inference_analysis_test(trt_dynamic_shape_test SRCS trt_dynamic_shape_test.cc
          EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
......
@@ -21,24 +21,27 @@ limitations under the License. */
namespace paddle {
namespace inference {

void TestDynamic(bool with_dynamic = true) {
  std::string model_dir = FLAGS_infer_model + "/conv_bn_swish_split_gelu";
  AnalysisConfig config;
  config.EnableUseGpu(100, 0);
  config.SetModel(model_dir + "/model", model_dir + "/params");
  config.SwitchUseFeedFetchOps(false);
  // Set the input's min, max, opt shape
  config.EnableTensorRtEngine(1 << 30, 1, 1,
                              AnalysisConfig::Precision::kFloat32, false, true);
if (with_dynamic) {
std::map<std::string, std::vector<int>> min_input_shape = {
{"image", {1, 1, 3, 3}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"image", {1, 1, 10, 10}}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"image", {1, 1, 3, 3}}};
    config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
                                  opt_input_shape);
}
  auto predictor = CreatePaddlePredictor(config);
  auto input_names = predictor->GetInputNames();
  int channels = 1;
@@ -64,5 +67,8 @@ TEST(AnalysisPredictor, use_gpu) {
  output_t->copy_to_cpu(out_data.data());
}

TEST(AnalysisPredictor, trt_dynamic) { TestDynamic(true); }
TEST(AnalysisPredictor, trt_static) { TestDynamic(false); }
} // namespace inference
} // namespace paddle