diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 94b5400a470d9233777d97e700afe007237c120b..1639df7d4b061927a3ecfeec03f6969c661e964f 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -973,4 +973,5 @@ USE_TRT_CONVERTER(gelu); USE_TRT_CONVERTER(multihead_matmul); USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm); USE_TRT_CONVERTER(skip_layernorm); +USE_TRT_CONVERTER(scale); #endif diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index dacea1ebcb2efa5d74d8f1b37f279fb40bda6f5e..a5989bedd885e1fc5c6b7f68478e996b39aca0b6 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -3,7 +3,8 @@ nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc - shuffle_channel_op.cc swish_op.cc instance_norm_op.cc emb_eltwise_layernorm.cc skip_layernorm.cc + shuffle_channel_op.cc swish_op.cc instance_norm_op.cc +emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS diff --git a/paddle/fluid/inference/tensorrt/convert/concat_op.cc b/paddle/fluid/inference/tensorrt/convert/concat_op.cc index ec771850edf5f4f0207fb664e26b2d9b98a7a128..1793920f0a9ca8518722c30781b4b2781584de46 100644 --- a/paddle/fluid/inference/tensorrt/convert/concat_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/concat_op.cc @@ -39,7 +39,9 @@ class ConcatOpConverter : public OpConverter { auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, itensors.data(), itensors.size()); - axis = axis - 1; // Remove batch dim + if (!engine_->with_dynamic_shape()) { + axis = axis - 1; // Remove batch dim + } layer->setAxis(axis); auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, "concat", {output_name}, test_mode); diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 4ae2f91d1a673d41b256113475982c1060518ad8..1c0deda525a8aec6ae347ac9e34537effabf308d 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -48,20 +48,68 @@ class ElementwiseWeightOpConverter : public OpConverter { PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); auto* X = engine_->GetITensor(op_desc.Input("X").front()); - nvinfer1::Dims dims_x = X->getDimensions(); - std::vector no_batch_dims; - int start_index = 0; - - if (engine_->with_dynamic_shape()) start_index = 1; - for (; start_index < dims_x.nbDims; start_index++) - no_batch_dims.push_back(dims_x.d[start_index]); - auto* Y_v = scope.FindVar(op_desc.Input("Y").front()); PADDLE_ENFORCE_NOT_NULL(Y_v); auto* Y_t = Y_v->GetMutable(); float* weight_data = nullptr; weight_data = engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t, false); + nvinfer1::Dims dims_x = X->getDimensions(); + + auto regist_eltwise_weight = [&](nvinfer1::ScaleMode scale_mode) { + TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(Y_t->numel())}; + TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + if (op_type_ == "add") { + nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *X, scale_mode, shift_weights.get(), + scale_weights.get(), power_weights.get()); + layer = scale_layer; + } else if (op_type_ == "mul") { + nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *X, scale_mode, scale_weights.get(), + shift_weights.get(), power_weights.get()); + layer = scale_layer; + } + + auto output_name = op_desc.Output("Out")[0]; + RreplenishLayerAndOutput(layer, "elementwise_" + op_type_, {output_name}, + test_mode); + if (op_desc.HasAttr("enable_int8")) { +#if IS_TRT_VERSION_GE(5000) + CHECK(op_desc.HasAttr("X_scale")); + float x_scale = boost::get(op_desc.GetAttr("X_scale")); + engine_->SetTensorDynamicRange(X, x_scale); +#endif + } + }; + + if (engine_->with_dynamic_shape()) { + if (Y_t->dims().size() == 1) { + auto scale_mode = nvinfer1::ScaleMode::kCHANNEL; + PADDLE_ENFORCE_EQ(Y_t->dims()[0], dims_x.d[1], + platform::errors::InvalidArgument( + "The Bias's size(%d) should be equal to the " + "first dim(%d) of the Input.", + Y_t->dims()[0], dims_x.d[1])); + regist_eltwise_weight(scale_mode); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "TensorRT Dynamic shape unsupported weight shape for Elementwise " + "op!")); + } + return; + } + + std::vector no_batch_dims; + int start_index = 0; + + for (; start_index < dims_x.nbDims; start_index++) + no_batch_dims.push_back(dims_x.d[start_index]); auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; @@ -91,36 +139,7 @@ class ElementwiseWeightOpConverter : public OpConverter { } else { PADDLE_THROW("TensorRT unsupported weight Shape for Elementwise op!"); } - - TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, - static_cast(weight_data), - static_cast(Y_t->numel())}; - TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr, - 0}; - TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, - 0}; - if (op_type_ == "add") { - nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( - engine_, Scale, *X, scale_mode, shift_weights.get(), - scale_weights.get(), power_weights.get()); - layer = scale_layer; - } else if (op_type_ == "mul") { - nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( - engine_, Scale, *X, scale_mode, scale_weights.get(), - shift_weights.get(), power_weights.get()); - layer = scale_layer; - } - - auto output_name = op_desc.Output("Out")[0]; - RreplenishLayerAndOutput(layer, "elementwise_" + op_type_, {output_name}, - test_mode); - if (op_desc.HasAttr("enable_int8")) { -#if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr("X_scale")); - float x_scale = boost::get(op_desc.GetAttr("X_scale")); - engine_->SetTensorDynamicRange(X, x_scale); -#endif - } + regist_eltwise_weight(scale_mode); } protected: @@ -146,44 +165,62 @@ class ElementwiseTensorOpConverter : public OpConverter { auto* X = engine_->GetITensor(op_desc.Input("X").front()); auto* Y = engine_->GetITensor(op_desc.Input("Y").front()); + std::vector itensors; + itensors.push_back(X); + itensors.push_back(Y); nvinfer1::Dims dims_x = X->getDimensions(); nvinfer1::Dims dims_y = Y->getDimensions(); int axis = boost::get(op_desc.GetAttr("axis")); auto output_name = op_desc.Output("Out")[0]; + + auto common_func = [&](nvinfer1::ILayer* layer) { + RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode); + if (op_desc.HasAttr("enable_int8")) { +#if IS_TRT_VERSION_GE(5000) + CHECK(op_desc.HasAttr("X_scale")); + CHECK(op_desc.HasAttr("Y_scale")); + float x_scale = boost::get(op_desc.GetAttr("X_scale")); + float y_scale = boost::get(op_desc.GetAttr("Y_scale")); + engine_->SetTensorDynamicRange(X, x_scale); + engine_->SetTensorDynamicRange(Y, y_scale); +#endif + } + }; + if (CheckDims(dims_x, dims_y)) { // The two input tensor should have the same dims VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer"; - nvinfer1::IElementWiseLayer* elet_layer = TRT_ENGINE_ADD_LAYER( - engine_, ElementWise, *const_cast(X), - *const_cast(Y), op_pair->second); + nvinfer1::IElementWiseLayer* elet_layer = + TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *X, *Y, op_pair->second); layer = elet_layer; } else { VLOG(3) << "Convert a fluid elementwise op to TensorRT " "ElementWisePluginLayer"; - - plugin::ElementWisePlugin* plugin = - new plugin::ElementWisePlugin(op_type_, dims_x, dims_y, axis); - plugin->AddInput(X); - plugin->AddInput(Y); - nvinfer1::IPluginLayer* plugin_layer = engine_->AddPlugin( - const_cast(plugin->GetInputs().data()), 2, - reinterpret_cast(plugin)); - - layer = plugin_layer; - } - RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode); - if (op_desc.HasAttr("enable_int8")) { -#if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr("X_scale")); - CHECK(op_desc.HasAttr("Y_scale")); - float x_scale = boost::get(op_desc.GetAttr("X_scale")); - float y_scale = boost::get(op_desc.GetAttr("Y_scale")); - engine_->SetTensorDynamicRange(X, x_scale); - engine_->SetTensorDynamicRange(Y, y_scale); + if (engine_->with_dynamic_shape()) { +#if IS_TRT_VERSION_GE(6000) + plugin::ElementwisePluginDynamic* plugin = + new plugin::ElementwisePluginDynamic(op_type_, axis); + layer = engine_->AddPluginV2(itensors.data(), 2, plugin); +#else + PADDLE_THROW(platform::errors::Fatal( + "You are running the TRT Dynamic Shape mode, need to confirm that " + "your TRT version is no less than 6.0")); #endif + } else { + plugin::ElementWisePlugin* plugin = + new plugin::ElementWisePlugin(op_type_, dims_x, dims_y, axis); + plugin->AddInput(X); + plugin->AddInput(Y); + nvinfer1::IPluginLayer* plugin_layer = engine_->AddPlugin( + plugin->GetInputs().data(), 2, + reinterpret_cast(plugin)); + + layer = plugin_layer; + } } + common_func(layer); } protected: diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index d76057c731468e8a0c15b0251dd535841f1725be..e7605cceb7b45b95d6bd81f4bf69a9fdb0d7e276 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -19,10 +19,10 @@ namespace paddle { namespace inference { namespace tensorrt { -void DealCeilMode(const nvinfer1::Dims &input_shape, std::vector ksize, - std::vector strides, std::vector paddings, - nvinfer1::DimsHW *pre_pad, nvinfer1::DimsHW *post_pad, - int input_dims) { +inline void DealCeilMode(const nvinfer1::Dims &input_shape, + std::vector ksize, std::vector strides, + std::vector paddings, nvinfer1::DimsHW *pre_pad, + nvinfer1::DimsHW *post_pad, int input_dims) { int input_height = input_shape.d[input_dims - 2]; int input_width = input_shape.d[input_dims - 1]; int floor_h_output_size = @@ -112,6 +112,31 @@ class Pool2dOpConverter : public OpConverter { #endif } + if (engine_->with_dynamic_shape()) { + if (!adaptive && pool_type == "max" && !global_pooling && !ceil_mode) { + auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *input1, + nv_pool_type, nv_ksize); + pool_layer->setStride(nv_strides); + pool_layer->setPadding(nv_paddings); + layer = pool_layer; + } else { +#if IS_TRT_VERSION_GE(6000) + plugin::PoolPluginDynamic *plugin = + new plugin::PoolPluginDynamic(ceil_mode, pool_type, adaptive, ksize, + strides, paddings, global_pooling); + layer = engine_->AddPluginV2(&input1, 1, plugin); +#endif + } + auto output_name = op_desc.Output("Out")[0]; + layer->setName(("pool2d (Output: " + output_name + ")").c_str()); + layer->getOutput(0)->setName(output_name.c_str()); + engine_->SetITensor(output_name, layer->getOutput(0)); + if (test_mode) { + engine_->DeclareOutput(output_name); + } + return; + } + if (global_pooling == true) { nv_ksize.d[0] = input_shape.d[input_dims - 2]; nv_ksize.d[1] = input_shape.d[input_dims - 1]; diff --git a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc index d327a743662aa5169901846e40232d593a158499..88dd1e0b5247a51d393ed334c5f9f7e7b944bc40 100644 --- a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc @@ -50,10 +50,22 @@ class PReluOpConverter : public OpConverter { TensorCopySync(*alpha_tensor, cpu_place, alpha_tensor_temp.get()); float* alpha_data = alpha_tensor_temp->mutable_data(cpu_place); - plugin::PReluPlugin* plugin = - new plugin::PReluPlugin(alpha_data, alpha_tensor_temp->numel(), mode); - nvinfer1::IPluginLayer* layer = - engine_->AddPlugin(&input, input_num, plugin); + nvinfer1::ILayer* layer = nullptr; + if (engine_->with_dynamic_shape()) { +#if IS_TRT_VERSION_GE(6000) + plugin::PReluPluginDynamic* plugin = new plugin::PReluPluginDynamic( + alpha_data, alpha_tensor_temp->numel(), mode); + layer = engine_->AddPluginV2(&input, input_num, plugin); +#else + PADDLE_THROW(platform::errors::Fatal( + "You are running the TRT Dynamic Shape mode, need to confirm that " + "your TRT version is no less than 6.0")); +#endif + } else { + plugin::PReluPlugin* plugin = + new plugin::PReluPlugin(alpha_data, alpha_tensor_temp->numel(), mode); + layer = engine_->AddPlugin(&input, input_num, plugin); + } // keep alpha tensor to avoid release it's memory engine_->SetWeights(op_desc.Input("Alpha")[0], std::move(alpha_tensor_temp)); diff --git a/paddle/fluid/inference/tensorrt/convert/scale_op.cc b/paddle/fluid/inference/tensorrt/convert/scale_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..21513105b27f39412d27090e9c95c68d1b985d38 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/scale_op.cc @@ -0,0 +1,84 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * ConcatOp + */ +class ScaleOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid scale op to tensorrt mul layer without bias"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + std::vector itensors; + std::string input_name = op_desc.Input("X").front(); + std::string out_name = op_desc.Output("Out").front(); + + auto input = engine_->GetITensor(input_name); + bool bias_after_scale = + boost::get(op_desc.GetAttr("bias_after_scale")); + float bias = boost::get(op_desc.GetAttr("bias")); + float scale = boost::get(op_desc.GetAttr("scale")); + auto create_weights = [&](float data, std::string type) -> float* { + std::unique_ptr tmp_tensor(new framework::Tensor()); + tmp_tensor->Resize({1}); + auto* tmp_data = tmp_tensor->mutable_data(platform::CPUPlace()); + tmp_data[0] = data; + engine_->SetWeights(out_name + "_scale_op_" + type, + std::move(tmp_tensor)); + return tmp_data; + }; + + float* bias_ptr = create_weights(bias, "bias"); + float* scale_ptr = create_weights(scale, "scale"); + + TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, + static_cast(scale_ptr), 1}; + TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, + static_cast(bias_ptr), 1}; + TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + nvinfer1::ILayer* layer = nullptr; + if (bias_after_scale) { + layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *input, nvinfer1::ScaleMode::kUNIFORM, + shift_weights.get(), scale_weights.get(), power_weights.get()); + } else { + // add bias + layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *(input), nvinfer1::ScaleMode::kUNIFORM, + shift_weights.get(), power_weights.get(), power_weights.get()); + // mul scale + layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *(layer->getOutput(0)), nvinfer1::ScaleMode::kUNIFORM, + power_weights.get(), scale_weights.get(), power_weights.get()); + } + + RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(scale, ScaleOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc b/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc index 471f39597477575c69dfa72bfc2159b3e6520723..5e7fad6132197079b7e4a46cc2c807ad801cf6b3 100644 --- a/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc @@ -40,6 +40,12 @@ class ShuffleChannelOpConverter : public OpConverter { int w = input_dims.d[2]; int group = boost::get(op_desc.GetAttr("group")); + if (engine_->with_dynamic_shape()) { + PADDLE_THROW(platform::errors::Fatal( + "You are running the TRT Dynamic Shape mode, " + "the shuffle_channel op does not support dynamic shape yet")); + } + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); nvinfer1::Dims4 reshape_dim(group, c / group, h, w); layer->setReshapeDimensions(reshape_dim); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index facd473dd396b6221fe2e362c486d5222dd561c4..26eb26926fa9ffd0f45af2541059f55bf909e9a7 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -72,6 +72,7 @@ struct SimpleOpTypeSetTeller : public Teller { "instance_norm", "gelu", "layer_norm", + "scale", }; }; diff --git a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu index 9aed3ddab1448fde7cb6b0e13bcf0b05e23622e9..0ec803fe64afadd970777e3b0d0ab5d37fcc4d22 100644 --- a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu @@ -21,64 +21,41 @@ namespace inference { namespace tensorrt { namespace plugin { -ElementWisePlugin* CreateElementWisePluginDeserialize(const void* buffer, +ElementWisePlugin *CreateElementWisePluginDeserialize(const void *buffer, size_t length) { return new ElementWisePlugin(buffer, length); } REGISTER_TRT_PLUGIN("elementwise_plugin", CreateElementWisePluginDeserialize); namespace details { - template struct Add { - __device__ T operator()(const T& a, const T& b) const { return a + b; } + __device__ T operator()(const T &a, const T &b) const { return a + b; } }; template struct Mul { - __device__ T operator()(const T& a, const T& b) const { return a * b; } + __device__ T operator()(const T &a, const T &b) const { return a * b; } }; +} // namespace details template -__global__ void ColumnWiseKernel(Operator op, const T* x, const T* y, T* out, - int batch_size, int num_rows, int num_cols) { - for (int batch_id = 0; batch_id < batch_size; ++batch_id) { - int row = blockIdx.x; - for (; row < num_rows; row += gridDim.x) { - T value_y = y[batch_id * num_rows + row]; - int col = threadIdx.x; - int offset = (batch_id * num_rows + row) * num_cols; - for (; col < num_cols; col += blockDim.x) { - T value_x = x[offset + col]; - out[offset + col] = op(value_x, value_y); - } - } - } -} - -template -static void ElementWise(Operator op, const T* x, const T* y, T* out, - int batch_size, int prev, int midd, int post, - cudaStream_t stream) { - const int kThreadsPerBlock = 1024; - const int kMaximumBlocks = 65535; - if (prev == 1) { - int num_threads = (post > kThreadsPerBlock) ? kThreadsPerBlock - : (((post + 31) >> 5) << 5); - int num_blocks = (midd < kMaximumBlocks) ? midd : kMaximumBlocks; - ColumnWiseKernel<<>>( - op, x, y, out, batch_size, midd, post); - } else if (post == 1) { - PADDLE_THROW("Not implemented."); - } else { - PADDLE_THROW("Not implemented."); +__global__ void elementwise_kernel(const size_t total, const T *x_data, + const T *y_data, T *out_data, int pre, int n, + int post, Operator op) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < total) { + int idx = tid / post % n; +#if __CUDA_ARCH__ >= 350 + out_data[tid] = op(__ldg(x_data + tid), __ldg(y_data + idx)); +#else + out_data[tid] = op(x_data[tid], y_data[idx]); +#endif } } -} // namespace details - nvinfer1::Dims ElementWisePlugin::getOutputDimensions( - int index, const nvinfer1::Dims* input_dims, int num_inputs) { + int index, const nvinfer1::Dims *input_dims, int num_inputs) { PADDLE_ENFORCE_EQ(index, 0); PADDLE_ENFORCE_EQ(num_inputs, 2); PADDLE_ENFORCE_NOT_NULL(input_dims); @@ -119,25 +96,137 @@ int ElementWisePlugin::initialize() { return 0; } -int ElementWisePlugin::enqueue(int batch_size, const void* const* inputs, - void** outputs, void* workspace, +int ElementWisePlugin::enqueue(int batch_size, const void *const *inputs, + void **outputs, void *workspace, cudaStream_t stream) { - const float* x = reinterpret_cast(inputs[0]); - const float* y = reinterpret_cast(inputs[1]); - float* out = reinterpret_cast(outputs[0]); + const float *x = reinterpret_cast(inputs[0]); + const float *y = reinterpret_cast(inputs[1]); + float *out = reinterpret_cast(outputs[0]); + + int num = batch_size * prev_size_ * midd_size_ * post_size_; + int thread = 256; + int block = (num + thread - 1) / thread; + if (type_ == "add") { + elementwise_kernel<<>>( + num, x, y, out, prev_size_, batch_size * midd_size_, post_size_, + details::Add()); + } else if (type_ == "mul") { + elementwise_kernel<<>>( + num, x, y, out, prev_size_, batch_size * midd_size_, post_size_, + details::Mul()); + } else { + PADDLE_THROW(platform::errors::Fatal( + "The %s type elementwise is not implemented in trt plugin.", type_)); + } + + return cudaGetLastError() != cudaSuccess; +} + +// Dynamic Plugin below. +#if IS_TRT_VERSION_GE(6000) + +int ElementwisePluginDynamic::initialize() { return 0; } + +size_t ElementwisePluginDynamic::getSerializationSize() const { return 0; } + +void ElementwisePluginDynamic::serialize(void *buffer) const {} + +nvinfer1::DimsExprs ElementwisePluginDynamic::getOutputDimensions( + int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs, + nvinfer1::IExprBuilder &expr_builder) { + return inputs[0]; +} + +bool ElementwisePluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs, + int nb_outputs) { + PADDLE_ENFORCE_NOT_NULL( + in_out, platform::errors::InvalidArgument( + "The input of swish plugin shoule not be nullptr.")); + + PADDLE_ENFORCE_LT( + pos, nb_inputs + nb_outputs, + platform::errors::InvalidArgument("The pos(%d) should be less than the " + "num(%d) of the input and the output.", + pos, nb_inputs + nb_outputs)); + (in_out && pos < (nb_inputs + nb_outputs)); + + const nvinfer1::PluginTensorDesc &in = in_out[pos]; + if (pos == 0) { + return (in.type == nvinfer1::DataType::kFLOAT) && + (in.format == nvinfer1::TensorFormat::kLINEAR); + } + const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; + // output + return in.type == prev.type && in.format == prev.format; +} + +nvinfer1::DataType ElementwisePluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType *input_types, int nb_inputs) const { + PADDLE_ENFORCE_EQ(index, 0, + platform::errors::InvalidArgument( + "The Elementwise Plugin only has one input, so the " + "index value should be 0, but get %d.", + index)); + return input_types[0]; +} + +int ElementwisePluginDynamic::enqueue( + const nvinfer1::PluginTensorDesc *input_desc, + const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs, + void *const *outputs, void *workspace, cudaStream_t stream) { + auto x_dims = input_desc[0].dims; + auto y_dims = input_desc[1].dims; + int axis = (axis_ == -1) ? x_dims.nbDims - y_dims.nbDims : axis_; + int batch_size = x_dims.d[0]; + + int prev_size = 1; + int midd_size = 1; + int post_size = 1; + for (int i = 0; i < axis; ++i) { + prev_size *= x_dims.d[i]; + } + + int trimed_nb_dims = y_dims.nbDims; + for (; trimed_nb_dims > 0; --trimed_nb_dims) { + if (y_dims.d[trimed_nb_dims - 1] != 1) { + break; + } + } + + for (int i = 0; i < trimed_nb_dims; ++i) { + PADDLE_ENFORCE_EQ(x_dims.d[i + axis], y_dims.d[i], + platform::errors::InvalidArgument( + "Broadcast dimension mismatch found in trt " + "elementwise plugin's x and y input.")); + midd_size *= y_dims.d[i]; + } + + for (int i = axis + trimed_nb_dims; i < x_dims.nbDims; ++i) { + post_size *= x_dims.d[i]; + } + + const float *x = static_cast(inputs[0]); + const float *y = static_cast(inputs[1]); + + float *out = static_cast(outputs[0]); + int num = prev_size * midd_size * post_size; + int thread = 256; + int block = (num + thread - 1) / thread; if (type_ == "add") { - details::ElementWise(details::Add(), x, y, out, batch_size, - prev_size_, midd_size_, post_size_, stream); + elementwise_kernel<<>>( + num, x, y, out, prev_size, midd_size, post_size, details::Add()); } else if (type_ == "mul") { - details::ElementWise(details::Mul(), x, y, out, batch_size, - prev_size_, midd_size_, post_size_, stream); + elementwise_kernel<<>>( + num, x, y, out, prev_size, midd_size, post_size, details::Mul()); } else { PADDLE_THROW("Not implemented."); } return cudaGetLastError() != cudaSuccess; } +#endif } // namespace plugin } // namespace tensorrt diff --git a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h index 3b040f14c531c540b8a855da85ecc3008224526c..e37511868d88f600a733df4ebb478e74a385be1b 100644 --- a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h @@ -25,8 +25,8 @@ namespace plugin { class ElementWisePlugin : public PluginTensorRT { public: - ElementWisePlugin(std::string type, nvinfer1::Dims const &dims_x, - nvinfer1::Dims const &dims_y, int axis) + ElementWisePlugin(std::string type, nvinfer1::Dims const& dims_x, + nvinfer1::Dims const& dims_y, int axis) : type_(type), dims_x_(dims_x), dims_y_(dims_y), @@ -35,9 +35,9 @@ class ElementWisePlugin : public PluginTensorRT { midd_size_(1), post_size_(1) {} - ElementWisePlugin(void const *serial_data, size_t serial_length) { + ElementWisePlugin(void const* serial_data, size_t serial_length) { deserializeBase(serial_data, serial_length); - const char *elementwise_type; + const char* elementwise_type; DeserializeValue(&serial_data, &serial_length, &elementwise_type); type_ = std::string(elementwise_type); DeserializeValue(&serial_data, &serial_length, &axis_); @@ -45,22 +45,22 @@ class ElementWisePlugin : public PluginTensorRT { DeserializeValue(&serial_data, &serial_length, &dims_y_); } - ElementWisePlugin *clone() const override { + ElementWisePlugin* clone() const override { // return new ElementWisePlugin(dims_x_, dims_y_, axis_); return nullptr; } - const char *getPluginType() const override { return "elementwise_plugin"; } + const char* getPluginType() const override { return "elementwise_plugin"; } nvinfer1::Dims getOutputDimensions(int index, - const nvinfer1::Dims *input_dims, + const nvinfer1::Dims* input_dims, int num_inputs) override; int initialize() override; // execute the layer - int enqueue(int batch_size, const void *const *inputs, void **outputs, - void *workspace, cudaStream_t stream); + int enqueue(int batch_size, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream); protected: size_t getSerializationSize() override { @@ -69,7 +69,7 @@ class ElementWisePlugin : public PluginTensorRT { getBaseSerializationSize(); } - void serialize(void *buffer) override { + void serialize(void* buffer) override { SerializeValue(&buffer, getPluginType()); serializeBase(buffer); SerializeValue(&buffer, type_.c_str()); @@ -87,6 +87,59 @@ class ElementWisePlugin : public PluginTensorRT { int post_size_; }; +#if IS_TRT_VERSION_GE(6000) +class ElementwisePluginDynamic : public DynamicPluginTensorRT { + public: + explicit ElementwisePluginDynamic(const std::string& type, int axis) + : type_(type), axis_(axis) {} + ElementwisePluginDynamic(void const* serialData, size_t serialLength) {} + nvinfer1::IPluginV2DynamicExt* clone() const override { + return new ElementwisePluginDynamic(type_, axis_); + } + + const char* getPluginType() const override { return "elementwise_plugin"; } + int getNbOutputs() const override { return 1; } + int initialize() override; + + size_t getSerializationSize() const override; + void serialize(void* buffer) const override; + + nvinfer1::DimsExprs getOutputDimensions( + int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, int nbOutputs) override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) override {} + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const override { + return 0; + } + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const override; + + void destroy() override { delete this; } + + private: + std::string type_; + int axis_; +}; +#endif + } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu index 6a718d47b1542b3cce97f6ff1f8744b4d58a8102..30f1c37ab18533c85252a415d76406a3d52a45d1 100644 --- a/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu @@ -194,8 +194,9 @@ int GeluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc, if (input_type == nvinfer1::DataType::kFLOAT) { const float* input = static_cast(inputs[0]); float* output = static_cast(outputs[0]); - gelu_kernel<<>>( - kA, num, input, output); + no_exact_gelu_kernel<<>>( + kAT, kBT, kCT, num, input, output); } else if (input_type == nvinfer1::DataType::kHALF) { #ifdef SUPPORTS_CUDA_FP16 const half* input = static_cast(inputs[0]); diff --git a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu index 17904a4ebcdb338ac253623b09f61f42119aacfd..48afcfce347d681fbbb291e478ead1fa28475a22 100644 --- a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu @@ -21,18 +21,18 @@ namespace inference { namespace tensorrt { namespace plugin { -PoolPlugin* CreatePoolPluginDeserialize(const void* buffer, size_t length) { +PoolPlugin *CreatePoolPluginDeserialize(const void *buffer, size_t length) { return new PoolPlugin(buffer, length); } REGISTER_TRT_PLUGIN("pool_plugin", CreatePoolPluginDeserialize); nvinfer1::Dims PoolPlugin::getOutputDimensions(int index, - const nvinfer1::Dims* inputDims, + const nvinfer1::Dims *inputDims, int nbInputs) { assert(nbInputs == 1); assert(index == 0); assert(inputDims[0].nbDims == 3); - nvinfer1::Dims const& input_dims = inputDims[0]; + nvinfer1::Dims const &input_dims = inputDims[0]; nvinfer1::Dims output_dims = input_dims; @@ -41,12 +41,12 @@ nvinfer1::Dims PoolPlugin::getOutputDimensions(int index, return output_dims; } -int PoolPlugin::enqueue(int batchSize, const void* const* inputs, - void** outputs, void* workspace, cudaStream_t stream) { - auto const& input_dims = this->getInputDims(0); +int PoolPlugin::enqueue(int batchSize, const void *const *inputs, + void **outputs, void *workspace, cudaStream_t stream) { + auto const &input_dims = this->getInputDims(0); int input_size = 0; - float const* idata = reinterpret_cast(inputs[0]); - float** odatas = reinterpret_cast(outputs); + float const *idata = reinterpret_cast(inputs[0]); + float **odatas = reinterpret_cast(outputs); std::vector input_shape = input_shape_; std::vector output_shape = output_shape_; @@ -72,6 +72,153 @@ int PoolPlugin::enqueue(int batchSize, const void* const* inputs, return cudaGetLastError() != cudaSuccess; } +// Dynamic Plugin below. +#if IS_TRT_VERSION_GE(6000) + +size_t PoolPluginDynamic::getSerializationSize() const { return 0; } + +void PoolPluginDynamic::serialize(void *buffer) const {} + +nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions( + int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs, + nvinfer1::IExprBuilder &expr_builder) { + PADDLE_ENFORCE_EQ(nb_inputs, 1, + platform::errors::InvalidArgument( + "The Split plugin should be only one input.")); + + PADDLE_ENFORCE_EQ( + inputs[0].d[1]->isConstant(), true, + platform::errors::InvalidArgument("The channel dimension should be " + "static, but we found it's dynamic.")); + nvinfer1::DimsExprs output(inputs[0]); + if (is_global_) { + output.d[2] = expr_builder.constant(1); + output.d[3] = expr_builder.constant(1); + return output; + } + if (adaptive_) { + output.d[2] = expr_builder.constant(ksize_[0]); + output.d[3] = expr_builder.constant(ksize_[1]); + return output; + } + + auto stri_0 = expr_builder.constant(strides_[0]); + auto stri_1 = expr_builder.constant(strides_[1]); + + auto tmp1_0 = + expr_builder.constant((-ksize_[0] + 2 * paddings_[0]) / strides_[0] + 1); + auto tmp1_1 = + expr_builder.constant((-ksize_[1] + 2 * paddings_[1]) / strides_[1] + 1); + + auto tmp2_0 = expr_builder.constant( + (-ksize_[0] + 2 * paddings_[0] + strides_[0] - 1) / strides_[0] + 1); + auto tmp2_1 = expr_builder.constant( + (-ksize_[1] + 2 * paddings_[1] + strides_[1] - 1) / strides_[1] + 1); + + auto *a_d = expr_builder.operation(nvinfer1::DimensionOperation::kCEIL_DIV, + *inputs[0].d[2], *stri_0); + auto *b_d = expr_builder.operation(nvinfer1::DimensionOperation::kCEIL_DIV, + *inputs[0].d[3], *stri_1); + + if (!ceil_mode_) { + output.d[2] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM, + *a_d, *tmp1_0); + output.d[3] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM, + *b_d, *tmp1_1); + } else { + output.d[2] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM, + *a_d, *tmp2_0); + output.d[3] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM, + *b_d, *tmp2_1); + } + + return output; +} + +bool PoolPluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs, + int nb_outputs) { + PADDLE_ENFORCE_NOT_NULL( + in_out, platform::errors::InvalidArgument( + "The input of swish plugin shoule not be nullptr.")); + + PADDLE_ENFORCE_LT( + pos, nb_inputs + nb_outputs, + platform::errors::InvalidArgument("The pos(%d) should be less than the " + "num(%d) of the input and the output.", + pos, nb_inputs + nb_outputs)); + (in_out && pos < (nb_inputs + nb_outputs)); + + return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) && + in_out[pos].format == nvinfer1::PluginFormat::kNCHW); +} + +nvinfer1::DataType PoolPluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType *input_types, int nb_inputs) const { + PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument( + "The Pool Plugin only has one input, so the " + "index value should be 0, but get %d.", + index)); + PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT), true, + platform::errors::InvalidArgument( + "The input type should be half or float")); + return input_types[0]; +} + +int PoolPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, + const nvinfer1::PluginTensorDesc *output_desc, + const void *const *inputs, void *const *outputs, + void *workspace, cudaStream_t stream) { + auto input_dims = input_desc[0].dims; + int n = input_dims.d[0]; + int c = input_dims.d[1]; + int h = input_dims.d[2]; + int w = input_dims.d[3]; + + const float *input = static_cast(inputs[0]); + float *output = static_cast(outputs[0]); + + std::vector input_shape, output_shape; + for (int i = 0; i < input_dims.nbDims; i++) + input_shape.push_back(input_dims.d[i]); + output_shape = input_shape; + + std::vector ksize = ksize_; + std::vector paddings = paddings_; + if (is_global_) { + ksize[0] = h; + ksize[1] = w; + paddings[0] = 0; + paddings[1] = 0; + output_shape[2] = 1; + output_shape[3] = 1; + } else { + auto data_dim = CalcOutputSize({h, w}, ceil_mode_, adaptive_, ksize_, + strides_, paddings_); + output_shape[2] = data_dim[0]; + output_shape[3] = data_dim[1]; + } + + if (pool_type_ == "max") { + paddle::operators::math::MaxPool pool_process; + paddle::operators::math::Pool2dDirectCUDAFunctor< + paddle::operators::math::MaxPool, float> + pool2d_forward; + pool2d_forward(input, input_shape, output_shape, ksize, strides_, paddings, + pool_process, true, adaptive_, output, stream); + } else if (pool_type_ == "avg") { + paddle::operators::math::AvgPool pool_process; + paddle::operators::math::Pool2dDirectCUDAFunctor< + paddle::operators::math::AvgPool, float> + pool2d_forward; + pool2d_forward(input, input_shape, output_shape, ksize, strides_, paddings, + pool_process, true, adaptive_, output, stream); + } + + return cudaGetLastError() != cudaSuccess; +} +#endif + } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.h index 9b0591259abddec86850fe6036d60dd50faddfe9..6693a1fae4d4304af2f826894b119383ea704727 100644 --- a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.h @@ -24,6 +24,37 @@ namespace inference { namespace tensorrt { namespace plugin { +static std::vector CalcOutputSize(const std::vector& input_shape, + const bool& ceil_mode, + const bool& adaptive, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings) { + std::vector output_shape = input_shape; + if (adaptive) { + output_shape[0] = ksize[0]; + output_shape[1] = ksize[1]; + } else { + int output_h, output_w; + if (!ceil_mode) { + output_h = (input_shape[0] - ksize[0] + 2 * paddings[0]) / strides[0] + 1; + output_w = (input_shape[1] - ksize[1] + 2 * paddings[1]) / strides[1] + 1; + } else { + output_h = + (input_shape[0] - ksize[0] + 2 * paddings[0] + strides[0] - 1) / + strides[0] + + 1; + output_w = + (input_shape[1] - ksize[1] + 2 * paddings[1] + strides[1] - 1) / + strides[1] + + 1; + } + output_shape[0] = output_h; + output_shape[1] = output_w; + } + return output_shape; +} + class PoolPlugin : public PluginTensorRT { protected: size_t getSerializationSize() override { @@ -36,7 +67,7 @@ class PoolPlugin : public PluginTensorRT { // TRT will call this func when we need to serialize the configuration of // tensorrt. - void serialize(void *buffer) override { + void serialize(void* buffer) override { SerializeValue(&buffer, getPluginType()); serializeBase(buffer); SerializeValue(&buffer, ceil_mode_); @@ -66,34 +97,16 @@ class PoolPlugin : public PluginTensorRT { paddings_(paddings), input_shape_(input_shape) { output_shape_ = input_shape_; - if (adaptive_) { - output_shape_[1] = ksize[0]; - output_shape_[2] = ksize[1]; - } else { - int output_h, output_w; - if (!ceil_mode_) { - output_h = - (input_shape[1] - ksize_[0] + 2 * paddings_[0]) / strides_[0] + 1; - output_w = - (input_shape[2] - ksize_[1] + 2 * paddings_[1]) / strides_[1] + 1; - } else { - output_h = - (input_shape[1] - ksize_[0] + 2 * paddings_[0] + strides_[0] - 1) / - strides_[0] + - 1; - output_w = - (input_shape[2] - ksize_[1] + 2 * paddings_[1] + strides_[1] - 1) / - strides_[1] + - 1; - } - output_shape_[1] = output_h; - output_shape_[2] = output_w; - } + std::vector output_shape = + CalcOutputSize({input_shape_[1], input_shape_[2]}, ceil_mode_, + adaptive_, ksize_, strides_, paddings_); + output_shape_[1] = output_shape[0]; + output_shape_[2] = output_shape[1]; } // It was used for tensorrt deserialization. // It should not be called by users. - PoolPlugin(void const *serialData, size_t serialLength) { + PoolPlugin(void const* serialData, size_t serialLength) { deserializeBase(serialData, serialLength); DeserializeValue(&serialData, &serialLength, &ceil_mode_); DeserializeValue(&serialData, &serialLength, &pool_type_); @@ -105,18 +118,18 @@ class PoolPlugin : public PluginTensorRT { DeserializeValue(&serialData, &serialLength, &output_shape_); } - PoolPlugin *clone() const override { + PoolPlugin* clone() const override { return new PoolPlugin(ceil_mode_, pool_type_, adaptive_, ksize_, strides_, paddings_, input_shape_); } - const char *getPluginType() const override { return "pool_plugin"; } + const char* getPluginType() const override { return "pool_plugin"; } int getNbOutputs() const override { return 1; } - nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override; int initialize() override { return 0; } - int enqueue(int batchSize, const void *const *inputs, void **outputs, - void *workspace, cudaStream_t stream) override; + int enqueue(int batchSize, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream) override; private: bool ceil_mode_; @@ -129,6 +142,88 @@ class PoolPlugin : public PluginTensorRT { std::vector output_shape_; }; +#if IS_TRT_VERSION_GE(6000) +class PoolPluginDynamic : public DynamicPluginTensorRT { + public: + PoolPluginDynamic() {} + PoolPluginDynamic(const bool& ceil_mode, const std::string& pool_type, + const bool& adaptive, const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, const bool& is_global) + : ceil_mode_(ceil_mode), + pool_type_(pool_type), + adaptive_(adaptive), + ksize_(ksize), + strides_(strides), + paddings_(paddings), + is_global_(is_global) {} + + PoolPluginDynamic(void const* serialData, size_t serialLength) { + deserializeBase(serialData, serialLength); + DeserializeValue(&serialData, &serialLength, &ceil_mode_); + const char* pool_type; + DeserializeValue(&serialData, &serialLength, &pool_type); + pool_type_ = std::string(pool_type); + DeserializeValue(&serialData, &serialLength, &adaptive_); + DeserializeValue(&serialData, &serialLength, &ksize_); + DeserializeValue(&serialData, &serialLength, &strides_); + DeserializeValue(&serialData, &serialLength, &paddings_); + DeserializeValue(&serialData, &serialLength, &is_global_); + } + ~PoolPluginDynamic() {} + nvinfer1::IPluginV2DynamicExt* clone() const override { + return new PoolPluginDynamic(ceil_mode_, pool_type_, adaptive_, ksize_, + strides_, paddings_, is_global_); + } + + const char* getPluginType() const override { return "pool_plugin"; } + int getNbOutputs() const override { return 1; } + int initialize() override { return 0; } + + size_t getSerializationSize() const override; + void serialize(void* buffer) const override; + + nvinfer1::DimsExprs getOutputDimensions( + int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, int nbOutputs) override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) override {} + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const override { + return 0; + } + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const override; + + void destroy() override { delete this; } + + private: + bool ceil_mode_; + std::string pool_type_; + bool adaptive_; + std::vector ksize_; + std::vector strides_; + std::vector paddings_; + bool is_global_; +}; +#endif + } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu index 84f938eeb5fa50421a819978cd84c968919c96b3..1bde3c16d06dbeee0c4d87d68c0f48aae81cfd01 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu @@ -77,6 +77,84 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs, return cudaGetLastError() != cudaSuccess; } +#if IS_TRT_VERSION_GE(6000) + +int PReluPluginDynamic::initialize() { + cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size()); + cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float), + cudaMemcpyHostToDevice); + return 0; +} +size_t PReluPluginDynamic::getSerializationSize() const { return 0; } + +void PReluPluginDynamic::serialize(void *buffer) const {} + +nvinfer1::DimsExprs PReluPluginDynamic::getOutputDimensions( + int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs, + nvinfer1::IExprBuilder &expr_builder) { + return inputs[0]; +} + +bool PReluPluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs, + int nb_outputs) { + PADDLE_ENFORCE_NOT_NULL( + in_out, platform::errors::InvalidArgument( + "The input of swish plugin shoule not be nullptr.")); + + PADDLE_ENFORCE_LT( + pos, nb_inputs + nb_outputs, + platform::errors::InvalidArgument("The pos(%d) should be less than the " + "num(%d) of the input and the output.", + pos, nb_inputs + nb_outputs)); + (in_out && pos < (nb_inputs + nb_outputs)); + + return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) && + in_out[pos].format == nvinfer1::PluginFormat::kNCHW); +} + +nvinfer1::DataType PReluPluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType *input_types, int nb_inputs) const { + PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument( + "The PRelu Plugin only has one input, so the " + "index value should be 0, but get %d.", + index)); + PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT), true, + platform::errors::InvalidArgument( + "The input type should be half or float")); + return input_types[0]; +} + +int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, + const nvinfer1::PluginTensorDesc *output_desc, + const void *const *inputs, void *const *outputs, + void *workspace, cudaStream_t stream) { + auto input_dims = input_desc[0].dims; + const float *alpha = p_gpu_weight_; + const float *input = static_cast(inputs[0]); + float *output = static_cast(outputs[0]); + + std::vector input_shape; + for (int i = 0; i < input_dims.nbDims; i++) { + input_shape.push_back(input_dims.d[i]); + } + + if (mode_ == "channel") { + operators::math::PreluChannelWiseDirectCUDAFunctor + prelu_channel_wise; + prelu_channel_wise(stream, input, alpha, output, input_shape); + } else if (mode_ == "element") { + operators::math::PreluElementWiseDirectCUDAFunctor + prelu_element_wise; + prelu_element_wise(stream, input, alpha, output, input_shape); + } else { + operators::math::PreluScalarDirectCUDAFunctor prelu_scalar; + prelu_scalar(stream, input, alpha, output, input_shape); + } + return cudaGetLastError() != cudaSuccess; +} +#endif + } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h index a96649503f1c764e07370cb2b47b10f3dae72be4..4756ca2e0225795edc3bd3112b21e3b628ad5c0b 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h @@ -30,7 +30,7 @@ namespace plugin { class PReluPlugin : public PluginTensorRT { std::vector weight_; - float *p_gpu_weight_; + float* p_gpu_weight_; std::string mode_; protected: @@ -42,7 +42,7 @@ class PReluPlugin : public PluginTensorRT { // TRT will call this func when we need to serialize the configuration of // tensorrt. // It should not be called by users. - void serialize(void *buffer) override { + void serialize(void* buffer) override { SerializeValue(&buffer, getPluginType()); serializeBase(buffer); SerializeValue(&buffer, weight_); @@ -50,8 +50,8 @@ class PReluPlugin : public PluginTensorRT { } public: - PReluPlugin(const float *weight, const int weight_num, - std::string const &mode) + PReluPlugin(const float* weight, const int weight_num, + std::string const& mode) : mode_(mode) { weight_.resize(weight_num); std::copy(weight, weight + weight_num, weight_.data()); @@ -59,28 +59,96 @@ class PReluPlugin : public PluginTensorRT { // It was used for tensorrt deserialization. // It should not be called by users. - PReluPlugin(void const *serialData, size_t serialLength) { + PReluPlugin(void const* serialData, size_t serialLength) { deserializeBase(serialData, serialLength); DeserializeValue(&serialData, &serialLength, &weight_); - const char *prelu_mode; + const char* prelu_mode; DeserializeValue(&serialData, &serialLength, &prelu_mode); mode_ = std::string(prelu_mode); } ~PReluPlugin() { cudaFree(p_gpu_weight_); } int initialize() override; - PReluPlugin *clone() const override { + PReluPlugin* clone() const override { return new PReluPlugin(weight_.data(), weight_.size(), mode_); } - const char *getPluginType() const override { return "prelu_plugin"; } + const char* getPluginType() const override { return "prelu_plugin"; } int getNbOutputs() const override { return 1; } - nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override; - int enqueue(int batchSize, const void *const *inputs, void **outputs, - void *workspace, cudaStream_t stream) override; + int enqueue(int batchSize, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream) override; }; +#if IS_TRT_VERSION_GE(6000) +class PReluPluginDynamic : public DynamicPluginTensorRT { + public: + PReluPluginDynamic(const float* weight, const int weight_num, + std::string const& mode) + : mode_(mode) { + weight_.resize(weight_num); + std::copy(weight, weight + weight_num, weight_.data()); + } + + // It was used for tensorrt deserialization. + // It should not be called by users. + PReluPluginDynamic(void const* serialData, size_t serialLength) { + deserializeBase(serialData, serialLength); + DeserializeValue(&serialData, &serialLength, &weight_); + const char* prelu_mode; + DeserializeValue(&serialData, &serialLength, &prelu_mode); + mode_ = std::string(prelu_mode); + } + ~PReluPluginDynamic() { cudaFree(p_gpu_weight_); } + nvinfer1::IPluginV2DynamicExt* clone() const override { + return new PReluPluginDynamic(weight_.data(), weight_.size(), mode_); + } + + const char* getPluginType() const override { return "prelu_plugin"; } + int getNbOutputs() const override { return 1; } + int initialize() override; + + size_t getSerializationSize() const override; + void serialize(void* buffer) const override; + + nvinfer1::DimsExprs getOutputDimensions( + int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, int nbOutputs) override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) override {} + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const override { + return 0; + } + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const override; + + void destroy() override { delete this; } + + private: + std::vector weight_; + float* p_gpu_weight_; + std::string mode_; +}; +#endif + } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index fed8bd1145a2a75dbf2e0390cdcfd478d40763a1..13c48fcf576e1e05137875c00d67b81a47db1160 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -373,13 +373,18 @@ if(WITH_GPU AND TENSORRT_FOUND) EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS --infer_model=${TRT_MODEL_QUANT_RESNET_DIR}) + set(TEST_TRT_DYNAMIC_MODEL2 "${TRT_MODEL_INSTALL_DIR}/complex_model_dynamic") + if (NOT EXISTS ${TEST_TRT_DYNAMIC_MODEL2}) + inference_download_and_uncompress(${TEST_TRT_DYNAMIC_MODEL2} ${INFERENCE_URL}/tensorrt_test "complex_model_dynamic2.tar.gz") + endif() + set(TEST_TRT_DYNAMIC_MODEL "${TRT_MODEL_INSTALL_DIR}/conv_bn_swish_split_gelu") if (NOT EXISTS ${TEST_TRT_DYNAMIC_MODEL}) inference_download_and_uncompress(${TEST_TRT_DYNAMIC_MODEL} ${INFERENCE_URL}/tensorrt_test "conv_bn_swish_split_gelu.tar.gz") endif() inference_analysis_test(trt_dynamic_shape_test SRCS trt_dynamic_shape_test.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${TEST_TRT_DYNAMIC_MODEL}) + ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}) set(TEST_TRT_ERNIE_MODEL "${TRT_MODEL_INSTALL_DIR}/ernie_test") if (NOT EXISTS ${TEST_TRT_ERNIE_MODEL}) diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_test.cc index 59866fbb8f8fd0468d566aba756a6f17cb229c3e..989fa028a00b38f4f2bb0e45004c19be3d14b788 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_test.cc +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_test.cc @@ -22,7 +22,8 @@ namespace paddle { namespace inference { void TestDynamic(bool with_dynamic = true) { - std::string model_dir = FLAGS_infer_model + "/conv_bn_swish_split_gelu"; + std::string model_dir = + FLAGS_infer_model + "/conv_bn_swish_split_gelu/conv_bn_swish_split_gelu"; AnalysisConfig config; config.EnableUseGpu(100, 0); config.SetModel(model_dir + "/model", model_dir + "/params"); @@ -67,8 +68,69 @@ void TestDynamic(bool with_dynamic = true) { output_t->copy_to_cpu(out_data.data()); } +void TestDynamic2() { + std::string model_dir = + FLAGS_infer_model + "/complex_model_dynamic/complex_model_dynamic2"; + AnalysisConfig config; + config.EnableUseGpu(100, 0); + config.SetModel(model_dir + "/model", model_dir + "/params"); + config.SwitchUseFeedFetchOps(false); + // Set the input's min, max, opt shape + int batch_size = 1; + std::map> min_input_shape = { + {"image", {1, 3, 3, 3}}, {"in1", {1, 2, 1, 1}}, {"in2", {1, 2, 1, 1}}}; + std::map> max_input_shape = { + {"image", {1, 3, 10, 10}}, {"in1", {1, 2, 1, 1}}, {"in2", {1, 2, 1, 1}}}; + std::map> opt_input_shape = { + {"image", {1, 3, 5, 5}}, {"in1", {1, 2, 1, 1}}, {"in2", {1, 2, 1, 1}}}; + config.EnableTensorRtEngine(1 << 30, batch_size, 0, + AnalysisConfig::Precision::kFloat32, false, true); + + config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, + opt_input_shape); + + auto predictor = CreatePaddlePredictor(config); + int channels = 3; + int height = 5; + int width = 5; + int input_num = channels * height * width * 1; + + float *input = new float[input_num]; + memset(input, 0, input_num * sizeof(float)); + auto input_names = predictor->GetInputNames(); + auto input_t = predictor->GetInputTensor(input_names[0]); + input_t->Reshape({batch_size, channels, height, width}); + input_t->copy_from_cpu(input); + + auto input_t1 = predictor->GetInputTensor(input_names[1]); + input_t1->Reshape({batch_size, 2, 1, 1}); + std::vector first; + for (int i = 0; i < batch_size * 2; i++) first.push_back(1.0); + input_t1->copy_from_cpu(first.data()); + + auto input_t2 = predictor->GetInputTensor(input_names[2]); + input_t2->Reshape({batch_size, 2, 1, 1}); + input_t2->copy_from_cpu(first.data()); + + ASSERT_TRUE(predictor->ZeroCopyRun()); + + std::vector out_data; + auto output_names = predictor->GetOutputNames(); + auto output_t = predictor->GetOutputTensor(output_names[0]); + std::vector output_shape = output_t->shape(); + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, + std::multiplies()); + out_data.resize(out_num); + output_t->copy_to_cpu(out_data.data()); + std::vector result = {0.617728, 1.63504, 2.15771, 0.535556}; + for (size_t i = 0; i < out_data.size(); i++) { + EXPECT_NEAR(result[i], out_data[i], 1e-6); + } +} + TEST(AnalysisPredictor, trt_dynamic) { TestDynamic(true); } TEST(AnalysisPredictor, trt_static) { TestDynamic(false); } +TEST(AnalysisPredictor, trt_dynamic2) { TestDynamic2(); } } // namespace inference } // namespace paddle