diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index dda4be8f81c63fe844d2f0035ac0c3b88038e11d..ad0647236acb969dd32b7564a943db14ea83ee65 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1415,6 +1415,7 @@ USE_TRT_CONVERTER(tile);
 USE_TRT_CONVERTER(conv3d);
 USE_TRT_CONVERTER(conv3d_transpose);
 USE_TRT_CONVERTER(mish);
+USE_TRT_CONVERTER(deformable_conv);
 USE_TRT_CONVERTER(pool3d)
 #endif
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index b6aa0a230cc2d5454dfd3c38de930c7c17481464..a885b69fa7fbcc19e4fe4825410d2f862ba8c568 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -20,6 +20,7 @@ nv_library(tensorrt_converter
   mish_op.cc
   nearest_interp_v2_op.cc
   pool3d_op.cc
+  deformable_conv_op.cc
   DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
 
 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
diff --git a/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc b/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..02d460ffa1cbbfc3e3aa18e4b96fc356cc593da7
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc
@@ -0,0 +1,111 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h"
+
+namespace paddle {
+namespace framework {
+class Scope;
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class DeformableConvOpConverter : public OpConverter {
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    VLOG(3) << "convert a deformable conv op to tensorrt plugin";
+
+    framework::OpDesc op_desc(op, nullptr);
+    std::string input_name = op_desc.Input("Input").front();
+    std::string offset_name = op_desc.Input("Offset").front();
+    std::string mask_name = op_desc.Input("Mask").front();
+    std::string filter_name = op_desc.Input("Filter").front();
+
+    auto* input_tensor = engine_->GetITensor(input_name);
+    auto* offset_tensor = engine_->GetITensor(offset_name);
+    auto* mask_tensor = engine_->GetITensor(mask_name);
+    auto* filter_var = scope.FindVar(filter_name);
+    auto* filter_tensor = filter_var->GetMutable<framework::LoDTensor>();
+
+    float* filter_data =
+        engine_->GetWeightCPUData(filter_name, filter_tensor, false);
+
+    const int c_o = filter_tensor->dims()[0];
+    const int c_i = filter_tensor->dims()[1];
+    const int k_h = filter_tensor->dims()[2];
+    const int k_w = filter_tensor->dims()[3];
+    std::vector<int> kernel_dims = {c_o, c_i, k_h, k_w};
+
+    auto strides =
+        BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
+    auto paddings =
+        BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));
+    auto dilations =
+        BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("dilations"));
+
+    auto groups = BOOST_GET_CONST(int, op_desc.GetAttr("groups"));
+    auto deformable_groups =
+        BOOST_GET_CONST(int, op_desc.GetAttr("deformable_groups"));
+    auto im2col_step = BOOST_GET_CONST(int, op_desc.GetAttr("im2col_step"));
+
+    nvinfer1::Weights weights;
+    weights.count = filter_tensor->numel();
+    if (engine_->WithFp16()) {
+      auto half_filter_data = new half[filter_tensor->numel()];
+      for (int i = 0; i < filter_tensor->numel(); i++) {
+        half_filter_data[i] = static_cast<half>(filter_data[i]);
+      }
+      weights.type = nvinfer1::DataType::kHALF;
+      weights.values = half_filter_data;
+    } else {
+      weights.type = nvinfer1::DataType::kFLOAT;
+      weights.values = filter_data;
+    }
+    auto* deformable_conv_plugin = new plugin::DeformableConvPlugin(
+        engine_->WithFp16() ? nvinfer1::DataType::kHALF
+                            : nvinfer1::DataType::kFLOAT,
+        weights, kernel_dims, strides, paddings, dilations, groups,
+        deformable_groups, im2col_step);
+
+    std::vector<nvinfer1::ITensor*> deformable_conv_inputs;
+    deformable_conv_inputs.push_back(input_tensor);
+    deformable_conv_inputs.push_back(offset_tensor);
+    deformable_conv_inputs.push_back(mask_tensor);
+
+    auto* deformable_conv_layer = engine_->network()->addPluginV2(
+        deformable_conv_inputs.data(), deformable_conv_inputs.size(),
+        *deformable_conv_plugin);
+
+    std::vector<std::string> output_names;
+    output_names.push_back(op_desc.Output("Output").front());
+
+    RreplenishLayerAndOutput(deformable_conv_layer, "deformable_conv",
+                             output_names, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(deformable_conv, DeformableConvOpConverter);
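
One detail worth flagging in the fp16 branch above: `half_filter_data` is allocated with `new[]` and never freed, while the plugin constructor (see `copyToDevice` later in this patch) copies the weights to GPU memory immediately and keeps no host-side reference. A minimal leak-free variant, assuming that constructor behavior (true for `DeformableConvPlugin` in this patch, an assumption for any other consumer):

    // Sketch only: keep the half buffer alive until the plugin is constructed,
    // then let RAII release it; the plugin already holds a device copy.
    std::vector<half> half_filter_data;
    if (engine_->WithFp16()) {
      half_filter_data.resize(filter_tensor->numel());
      for (int i = 0; i < filter_tensor->numel(); i++) {
        half_filter_data[i] = static_cast<half>(filter_data[i]);
      }
      weights.type = nvinfer1::DataType::kHALF;
      weights.values = half_filter_data.data();
    } else {
      weights.type = nvinfer1::DataType::kFLOAT;
      weights.values = filter_data;
    }
    // ... construct DeformableConvPlugin as above; the vector may then be
    // destroyed, because the constructor copies the weights to the GPU.
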
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 13504f444109b70e645b84d012267e0b1176cc22..e9b1c90ab086c8cbab9daa32306bd119ccb1dbb5 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -143,7 +143,8 @@ struct SimpleOpTypeSetTeller : public Teller {
       "conv3d_transpose",
       "mish",
       "nearest_interp_v2",
-      "pool3d"};
+      "pool3d",
+      "deformable_conv"};
 };
 
 bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
@@ -332,6 +333,51 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
 #endif
   }
 
+  if (op_type == "deformable_conv") {
+    if (with_dynamic_shape) {
+      VLOG(3) << "Deformable conv trt plugin does not support dynamic shape";
+      return false;
+    }
+    auto* block = desc.Block();
+    auto input_name = desc.Input("Input")[0];
+    auto* input_desc = block->FindVar(input_name);
+    const auto input_shape = input_desc->GetShape();
+
+    if (input_shape.size() != 4) {
+      VLOG(3) << "Input of deformable conv should be 4-D Tensor, but got "
+              << input_shape.size();
+      return false;
+    }
+
+    auto filter_name = desc.Input("Filter")[0];
+    auto* filter_desc = block->FindVar(filter_name);
+    const auto filter_shape = filter_desc->GetShape();
+
+    int groups = BOOST_GET_CONST(int, desc.GetAttr("groups"));
+    if (input_shape[1] != filter_shape[1] * groups) {
+      VLOG(3) << "The number of input channels should be equal to filter "
+              << "channels * groups. But got input channels "
+              << input_shape[1] << ", filter channels " << filter_shape[1];
+      return false;
+    }
+
+    const std::vector<int> strides =
+        BOOST_GET_CONST(std::vector<int>, desc.GetAttr("strides"));
+    if (strides.size() != 2) {
+      VLOG(3) << "The size of strides should be 2, but got "
+              << strides.size();
+      return false;
+    }
+
+    const std::vector<int> paddings =
+        BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
+    if (paddings.size() != 2) {
+      VLOG(3) << "The size of paddings should be 2, but got "
+              << paddings.size();
+      return false;
+    }
+  }
+
   if (op_type == "matmul") {
     auto* block = desc.Block();
     if (block == nullptr) {
diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
index 9e93894e623c0070c17d03788f5c531db60d80bc..3eece7e500e687715f68d9ae158dd0bec449de67 100644
--- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -11,6 +11,7 @@ nv_library(tensorrt_plugin
   gather_nd_op_plugin.cu
   mish_op_plugin.cu
   pool3d_op_plugin.cu
+  deformable_conv_op_plugin.cu
   DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
 
 nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS
diff --git a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b090ad91454a596fe411ae61e7cac4c3385cc6d5
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu
@@ -0,0 +1,618 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <cuda_fp16.h>
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+static constexpr int kNumCUDAThreads = 512;
+static constexpr int kNumMaximumNumBlocks = 4096;
+
+static inline int NumBlocks(const int N) {
+  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
+                  kNumMaximumNumBlocks);
+}
+
+static inline int ConvOutputSize(int input_size, int filter_size, int dilation,
+                                 int padding, int stride) {
+  const int dkernel = dilation * (filter_size - 1) + 1;
+  int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
+  return output_size;
+}
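
`ConvOutputSize` is the standard convolution output-size formula; the Python test later in this patch re-implements the same math as `compute_output_size`. A self-contained sanity check of the formula (the local copy is just for illustration, since the function above is file-static):

    #include <cassert>

    // Self-contained copy of the formula in ConvOutputSize above.
    static int OutSize(int input, int filter, int dilation, int padding,
                       int stride) {
      const int dkernel = dilation * (filter - 1) + 1;
      return (input + 2 * padding - dkernel) / stride + 1;
    }

    int main() {
      // 32x32 input, 3x3 kernel: stride 1 / padding 1 / dilation 1 keeps 32;
      // stride 2 / padding 0 / dilation 2 gives (32 - 5) / 2 + 1 = 14.
      assert(OutSize(32, 3, /*dilation=*/1, /*padding=*/1, /*stride=*/1) == 32);
      assert(OutSize(32, 3, /*dilation=*/2, /*padding=*/0, /*stride=*/2) == 14);
      return 0;
    }
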
+
+nvinfer1::Weights DeformableConvPlugin::copyToDevice(const void* hostData,
+                                                     size_t count) {
+  int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2);
+  void* deviceData;
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&deviceData, count * num_bytes));
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpy(
+      deviceData, hostData, count * num_bytes, cudaMemcpyHostToDevice));
+  return nvinfer1::Weights{data_type_, deviceData, int64_t(count)};
+}
+
+void DeformableConvPlugin::serializeFromDevice(
+    void** hostBuffer, const nvinfer1::Weights& deviceWeights) const {
+  int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2);
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      cudaMemcpy(static_cast<char*>(*hostBuffer), deviceWeights.values,
+                 deviceWeights.count * num_bytes, cudaMemcpyDeviceToHost));
+  // Advance the caller's buffer cursor past the bytes just written.
+  *hostBuffer =
+      static_cast<char*>(*hostBuffer) + deviceWeights.count * num_bytes;
+}
+
+nvinfer1::Weights DeformableConvPlugin::deserializeToDevice(
+    const void** hostBuffer, size_t count) {
+  int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2);
+  nvinfer1::Weights w =
+      copyToDevice(static_cast<const char*>(*hostBuffer), count);
+  // Advance the caller's buffer cursor past the bytes just read.
+  *hostBuffer = static_cast<const char*>(*hostBuffer) + count * num_bytes;
+  return w;
+}
+
+DeformableConvPlugin::DeformableConvPlugin(
+    const nvinfer1::DataType data_type, const nvinfer1::Weights& weights,
+    const std::vector<int>& kernel_dims, const std::vector<int>& strides,
+    const std::vector<int>& paddings, const std::vector<int>& dilations,
+    const int groups, const int deformable_groups, const int im2col_step)
+    : data_type_(data_type),
+      groups_(groups),
+      deformable_groups_(deformable_groups),
+      im2col_step_(im2col_step) {
+  weights_ = copyToDevice(weights.values, weights.count);
+  kernel_dims_.insert(kernel_dims_.end(), kernel_dims.cbegin(),
+                      kernel_dims.cend());
+
+  strides_.insert(strides_.end(), strides.cbegin(), strides.cend());
+  paddings_.insert(paddings_.end(), paddings.cbegin(), paddings.cend());
+  dilations_.insert(dilations_.end(), dilations.cbegin(), dilations.cend());
+  PADDLE_ENFORCE_EQ(data_type_ == nvinfer1::DataType::kFLOAT ||
+                        data_type_ == nvinfer1::DataType::kHALF,
+                    true, platform::errors::InvalidArgument(
+                              "The DeformableConv TRT Plugin's input type "
+                              "should be float or half."));
+  PADDLE_ENFORCE_EQ(
+      paddings_.size(), strides_.size(),
+      platform::errors::InvalidArgument(
+          "The size of paddings (%d) is not equal to the size of strides "
+          "(%d).",
+          paddings_.size(), strides_.size()));
+}
+
+DeformableConvPlugin::DeformableConvPlugin(
+    const nvinfer1::DataType data_type, const nvinfer1::Weights& weights,
+    const std::vector<int>& kernel_dims, const std::vector<int>& strides,
+    const std::vector<int>& paddings, const std::vector<int>& dilations,
+    const int groups, const int deformable_groups, const int im2col_step,
+    const std::vector<int>& input_dim, const std::vector<int>& offset_dim,
+    const std::vector<int>& mask_dim, const std::vector<int>& output_dim)
+    : data_type_(data_type),
+      groups_(groups),
+      deformable_groups_(deformable_groups),
+      im2col_step_(im2col_step) {
+  weights_ = copyToDevice(weights.values, weights.count);
+  kernel_dims_.insert(kernel_dims_.end(), kernel_dims.cbegin(),
+                      kernel_dims.cend());
+
+  strides_.insert(strides_.end(), strides.cbegin(), strides.cend());
+  paddings_.insert(paddings_.end(), paddings.cbegin(), paddings.cend());
+  dilations_.insert(dilations_.end(), dilations.cbegin(), dilations.cend());
+  input_dim_.insert(input_dim_.end(), input_dim.cbegin(), input_dim.cend());
+  offset_dim_.insert(offset_dim_.end(), offset_dim.cbegin(),
+                     offset_dim.cend());
+  mask_dim_.insert(mask_dim_.end(), mask_dim.cbegin(), mask_dim.cend());
+  output_dim_.insert(output_dim_.end(), output_dim.cbegin(),
+                     output_dim.cend());
+  PADDLE_ENFORCE_EQ(data_type_ == nvinfer1::DataType::kFLOAT ||
+                        data_type_ == nvinfer1::DataType::kHALF,
+                    true, platform::errors::InvalidArgument(
+                              "The DeformableConv TRT Plugin's input type "
+                              "should be float or half."));
+  PADDLE_ENFORCE_EQ(
+      paddings_.size(), strides_.size(),
+      platform::errors::InvalidArgument(
+          "The size of paddings (%d) is not equal to the size of strides "
+          "(%d).",
+          paddings_.size(), strides_.size()));
+}
+
+DeformableConvPlugin::DeformableConvPlugin(const void* data, size_t length) {
+  DeserializeValue(&data, &length, &data_type_);
+  DeserializeValue(&data, &length, &strides_);
+  DeserializeValue(&data, &length, &paddings_);
+  DeserializeValue(&data, &length, &dilations_);
+  DeserializeValue(&data, &length, &groups_);
+  DeserializeValue(&data, &length, &deformable_groups_);
+  DeserializeValue(&data, &length, &im2col_step_);
+  DeserializeValue(&data, &length, &kernel_dims_);
+  int64_t count;
+  DeserializeValue(&data, &length, &count);
+  weights_ = deserializeToDevice(&data, count);
+  DeserializeValue(&data, &length, &input_dim_);
+  DeserializeValue(&data, &length, &offset_dim_);
+  DeserializeValue(&data, &length, &mask_dim_);
+  DeserializeValue(&data, &length, &output_dim_);
+}
+
+DeformableConvPlugin::~DeformableConvPlugin() {
+  if (weights_.values) {
+    cudaFree(const_cast<void*>(weights_.values));
+    weights_.values = nullptr;
+  }
+}
+
+const char* DeformableConvPlugin::getPluginType() const TRT_NOEXCEPT {
+  return "deformable_conv_plugin";
+}
+
+const char* DeformableConvPlugin::getPluginVersion() const TRT_NOEXCEPT {
+  return "1";
+}
+
+int DeformableConvPlugin::getNbOutputs() const TRT_NOEXCEPT { return 1; }
+
+nvinfer1::Dims DeformableConvPlugin::getOutputDimensions(
+    int index, const nvinfer1::Dims* inputs, int nb_input_dims) TRT_NOEXCEPT {
+  PADDLE_ENFORCE_EQ(nb_input_dims, 3,
+                    platform::errors::InvalidArgument(
+                        "The number of inputs should be equal to 3, but got %d",
+                        nb_input_dims));
+  nvinfer1::Dims ret;
+  ret.nbDims = inputs[0].nbDims;
+  ret.d[0] = kernel_dims_[0];
+  ret.d[1] = ConvOutputSize(inputs[0].d[1], kernel_dims_[2], dilations_[0],
+                            paddings_[0], strides_[0]);
+  ret.d[2] = ConvOutputSize(inputs[0].d[2], kernel_dims_[3], dilations_[1],
+                            paddings_[1], strides_[1]);
+  return ret;
+}
+
+bool DeformableConvPlugin::supportsFormat(
+    nvinfer1::DataType type, nvinfer1::TensorFormat format) const TRT_NOEXCEPT {
+  return ((type == data_type_ || type == nvinfer1::DataType::kINT32) &&
+          format == nvinfer1::TensorFormat::kLINEAR);
+}
+
+size_t DeformableConvPlugin::getWorkspaceSize(int max_batch_size) const
+    TRT_NOEXCEPT {
+  int c_i = input_dim_[0], h_i = input_dim_[1], w_i = input_dim_[2];
+  int k_h = kernel_dims_[2], k_w = kernel_dims_[3];
+  int c_o = output_dim_[0], h_o = output_dim_[1], w_o = output_dim_[2];
+  int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2);
+  size_t data_col_size = static_cast<size_t>(c_i * k_h * k_w * im2col_step_ *
+                                             h_o * w_o * num_bytes);
+  return data_col_size;
+}
+
+int DeformableConvPlugin::enqueue(int batch_size, const void* const* inputs,
+#if IS_TRT_VERSION_LT(8000)
+                                  void** outputs, void* workspace,
+#else
+                                  void* const* outputs, void* workspace,
+#endif
+                                  cudaStream_t stream) TRT_NOEXCEPT {
+  if (data_type_ == nvinfer1::DataType::kFLOAT) {
+    enqueue_impl<float>(batch_size, inputs, outputs, workspace, stream);
+  } else if (data_type_ == nvinfer1::DataType::kHALF) {
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+    enqueue_impl<half>(batch_size, inputs, outputs, workspace, stream);
+#else
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Current CUDA arch does not support fp16. Please use fp32 instead."));
+#endif
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "The DeformableConv TRT Plugin's input type should be float or "
+        "half."));
+  }
+  return cudaGetLastError() != cudaSuccess;
+}
+
+template <typename T>
+__device__ T kFloor(T x);
+
+template <>
+__device__ half kFloor<half>(half x) {
+  return hfloor(x);
+}
+
+template <>
+__device__ float kFloor<float>(float x) {
+  return floor(x);
+}
+
+template <typename T>
+__device__ T DmcnIm2colBilinear(const T* bottom_data, const int data_width,
+                                const int height, const int width, T h, T w) {
+  int h_low = kFloor(h);
+  int w_low = kFloor(w);
+  int h_high = h_low + 1;
+  int w_high = w_low + 1;
+
+  T h_low_t = h_low, w_low_t = w_low, one = 1.0f;
+  T lh = h - h_low_t;
+  T lw = w - w_low_t;
+  T hh = one - lh, hw = one - lw;
+
+  T v1 = 0;
+  if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low];
+  T v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+    v2 = bottom_data[h_low * data_width + w_high];
+  T v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+    v3 = bottom_data[h_high * data_width + w_low];
+  T v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+    v4 = bottom_data[h_high * data_width + w_high];
+
+  T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
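
`DmcnIm2colBilinear` is plain bilinear interpolation in which out-of-bounds neighbors contribute zero. A host-side float reference that could be used to unit-test the device code against (a hedged sketch; the name and structure are illustrative, not part of the patch):

    #include <cmath>

    // Host reference for DmcnIm2colBilinear<float>: samples `data` at the
    // fractional position (h, w); neighbors outside [0, height) x [0, width)
    // contribute zero, matching the device kernel's boundary handling.
    float BilinearRef(const float* data, int data_width, int height, int width,
                      float h, float w) {
      const int h_low = static_cast<int>(std::floor(h));
      const int w_low = static_cast<int>(std::floor(w));
      const float lh = h - h_low, lw = w - w_low;
      auto at = [&](int y, int x) -> float {
        if (y < 0 || y >= height || x < 0 || x >= width) return 0.f;
        return data[y * data_width + x];
      };
      return (1 - lh) * (1 - lw) * at(h_low, w_low) +
             (1 - lh) * lw * at(h_low, w_low + 1) +
             lh * (1 - lw) * at(h_low + 1, w_low) +
             lh * lw * at(h_low + 1, w_low + 1);
    }
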
+
+template <typename T>
+__global__ void ModulatedDeformableIm2colGpuKernel(
+    const int nthreads, const T* data_im, const T* data_offset,
+    const T* data_mask, const int height, const int width, const int kernel_h,
+    const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
+    const int stride_w, const int dilation_h, const int dilation_w,
+    const int channel_per_deformable_group, const int batch_size,
+    const int num_channels, const int deformable_group, const int height_col,
+    const int width_col, T* data_col) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int offset = blockDim.x * gridDim.x;
+
+  T minus_one = -1.0f, height_t = height, width_t = width;
+  for (size_t i = index; i < nthreads; i += offset) {
+    const int w_col = i % width_col;
+    const int h_col = (i / width_col) % height_col;
+    const int b_col = (i / width_col) / height_col % batch_size;
+    const int c_im = (i / width_col / height_col) / batch_size;
+    const int c_col = c_im * kernel_h * kernel_w;
+
+    const int deformable_group_index = c_im / channel_per_deformable_group;
+
+    const int h_in = h_col * stride_h - pad_h;
+    const int w_in = w_col * stride_w - pad_w;
+
+    T* data_col_ptr =
+        data_col +
+        ((c_col * batch_size + b_col) * height_col + h_col) * width_col +
+        w_col;
+    const T* data_im_ptr =
+        data_im + (b_col * num_channels + c_im) * height * width;
+    const T* data_offset_ptr =
+        data_offset +
+        (b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
+            kernel_w * height_col * width_col;
+    const T* data_mask_ptr =
+        data_mask +
+        (b_col * deformable_group + deformable_group_index) * kernel_h *
+            kernel_w * height_col * width_col;
+
+    for (int i = 0; i < kernel_h; ++i) {
+      for (int j = 0; j < kernel_w; ++j) {
+        const int data_offset_h_ptr =
+            ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col +
+            w_col;
+        const int data_offset_w_ptr =
+            ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
+            w_col;
+        const int data_mask_hw_ptr =
+            ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+
+        const T offset_h = data_offset_ptr[data_offset_h_ptr];
+        const T offset_w = data_offset_ptr[data_offset_w_ptr];
+        const T mask = data_mask_ptr[data_mask_hw_ptr];
+        T val = 0;
+        T h_im_t = h_in + i * dilation_h, w_im_t = w_in + j * dilation_w;
+        const T h_im = h_im_t + offset_h;
+        const T w_im = w_im_t + offset_w;
+        if (h_im > minus_one && w_im > minus_one && h_im < height_t &&
+            w_im < width_t) {
+          val = DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im,
+                                   w_im);
+        }
+        *data_col_ptr = val * mask;
+        data_col_ptr += batch_size * height_col * width_col;
+      }
+    }
+  }
+}
+
+template <typename T>
+void gemm_impl(cublasHandle_t handle, cublasOperation_t transa,
+               cublasOperation_t transb, int m, int n, int k, const T* alpha,
+               const T* A, int lda, const T* B, int ldb, const T* beta, T* C,
+               int ldc);
+
+template <>
+void gemm_impl<float>(cublasHandle_t handle, cublasOperation_t transa,
+                      cublasOperation_t transb, int m, int n, int k,
+                      const float* alpha, const float* A, int lda,
+                      const float* B, int ldb, const float* beta, float* C,
+                      int ldc) {
+  platform::dynload::cublasSgemm(handle, transa, transb, m, n, k, alpha, A,
+                                 lda, B, ldb, beta, C, ldc);
+}
+
+template <>
+void gemm_impl<half>(cublasHandle_t handle, cublasOperation_t transa,
+                     cublasOperation_t transb, int m, int n, int k,
+                     const half* alpha, const half* A, int lda, const half* B,
+                     int ldb, const half* beta, half* C, int ldc) {
+  platform::dynload::cublasHgemm(handle, transa, transb, m, n, k, alpha, A,
+                                 lda, B, ldb, beta, C, ldc);
+}
+
+template <typename T>
+int DeformableConvPlugin::enqueue_impl(int batch_size,
+                                       const void* const* inputs,
+                                       void** outputs, void* workspace,
+                                       cudaStream_t stream) {
+  const T* input = reinterpret_cast<const T*>(inputs[0]);
+  const T* offset = reinterpret_cast<const T*>(inputs[1]);
+  const T* mask = reinterpret_cast<const T*>(inputs[2]);
+  const T* filter = reinterpret_cast<const T*>(weights_.values);
+  T* output = reinterpret_cast<T*>(outputs[0]);
+
+  int c_i = input_dim_[0], h_i = input_dim_[1], w_i = input_dim_[2];
+  int k_h = kernel_dims_[2], k_w = kernel_dims_[3];
+  int c_o = output_dim_[0], h_o = output_dim_[1], w_o = output_dim_[2];
+
+  int input_stride = c_i * h_i * w_i;
+  int offset_stride = offset_dim_[0] * offset_dim_[1] * offset_dim_[2];
+  int mask_stride = mask_dim_[0] * mask_dim_[1] * mask_dim_[2];
+  int output_stride = c_o * h_o * w_o;
+
+  int M = c_o / groups_;
+  int N = im2col_step_ * h_o * w_o;
+  int K = c_i * k_h * k_w / groups_;
+
+  // c_i / deformable_groups
+  int channel_per_deformable_group = c_i / deformable_groups_;
+  // c_i * im2col_step * h_o * w_o
+  int num_kernels = c_i * im2col_step_ * h_o * w_o;
+
+  int blocks = NumBlocks(num_kernels);
+  int threads = kNumCUDAThreads;
+
+  T alpha = static_cast<T>(1.0f);
+  T beta = static_cast<T>(0.0f);
+
+  for (int i = 0; i < batch_size / im2col_step_; ++i) {
+    const T* data_im = input + i * im2col_step_ * input_stride;
+    const T* data_offset = offset + i * im2col_step_ * offset_stride;
+    const T* data_mask = mask + i * im2col_step_ * mask_stride;
+    T* data_col = reinterpret_cast<T*>(workspace);
+
+    ModulatedDeformableIm2colGpuKernel<T><<<blocks, threads, 0, stream>>>(
+        num_kernels, data_im, data_offset, data_mask, h_i, w_i, k_h, k_w,
+        paddings_[0], paddings_[1], strides_[0], strides_[1], dilations_[0],
+        dilations_[1], channel_per_deformable_group, im2col_step_, c_i,
+        deformable_groups_, h_o, w_o, data_col);
+
+    for (int g = 0; g < groups_; ++g) {
+      const T* weight = filter + g * M * K;
+      const T* col = data_col + g * K * N;
+      T* out = output + i * im2col_step_ * output_stride + g * M * N;
+      gemm_impl<T>(cublasHandle_, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
+                   col, N, weight, K, &beta, out, N);
+    }
+  }
+  return 0;
+}
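
The GEMM call in `enqueue_impl` uses the standard trick for driving column-major cuBLAS with row-major data: a row-major product filter(MxK) x col(KxN) is issued as `gemm(N, M, K)` with the operands swapped, so cuBLAS effectively computes C^T = B^T * A^T and the result lands in memory as the row-major MxN output. A standalone illustration of the mapping, float path only (a sketch using the raw cuBLAS API rather than the dynload wrapper above):

    #include <cublas_v2.h>

    // Row-major C[M][N] = A[M][K] * B[K][N] via column-major cuBLAS:
    // reading the row-major arrays as column-major transposes turns the
    // product into C^T = B^T * A^T, i.e. gemm(N, M, K) with B passed first.
    // This is exactly the call pattern used with gemm_impl<T> above.
    void RowMajorGemm(cublasHandle_t handle, int M, int N, int K,
                      const float* A, const float* B, float* C) {
      const float alpha = 1.0f, beta = 0.0f;
      cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, B, N, A,
                  K, &beta, C, N);
    }
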
+
+int DeformableConvPlugin::initialize() TRT_NOEXCEPT { return 0; }
+
+void DeformableConvPlugin::terminate() TRT_NOEXCEPT {}
+
+size_t DeformableConvPlugin::getSerializationSize() const TRT_NOEXCEPT {
+  size_t serialize_size = 0;
+  serialize_size += SerializedSize(data_type_);
+  serialize_size += SerializedSize(strides_);
+  serialize_size += SerializedSize(paddings_);
+  serialize_size += SerializedSize(dilations_);
+  serialize_size += SerializedSize(groups_);
+  serialize_size += SerializedSize(deformable_groups_);
+  serialize_size += SerializedSize(im2col_step_);
+  serialize_size += SerializedSize(kernel_dims_);
+  serialize_size += SerializedSize(weights_.count);
+  int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2);
+  serialize_size += weights_.count * num_bytes;
+  serialize_size += SerializedSize(input_dim_);
+  serialize_size += SerializedSize(offset_dim_);
+  serialize_size += SerializedSize(mask_dim_);
+  serialize_size += SerializedSize(output_dim_);
+  return serialize_size;
+}
+
+void DeformableConvPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
+  // The field order here must match getSerializationSize() above and the
+  // deserializing constructor field-for-field.
+  SerializeValue(&buffer, data_type_);
+  SerializeValue(&buffer, strides_);
+  SerializeValue(&buffer, paddings_);
+  SerializeValue(&buffer, dilations_);
+  SerializeValue(&buffer, groups_);
+  SerializeValue(&buffer, deformable_groups_);
+  SerializeValue(&buffer, im2col_step_);
+  SerializeValue(&buffer, kernel_dims_);
+  SerializeValue(&buffer, weights_.count);
+  serializeFromDevice(&buffer, weights_);
+  SerializeValue(&buffer, input_dim_);
+  SerializeValue(&buffer, offset_dim_);
+  SerializeValue(&buffer, mask_dim_);
+  SerializeValue(&buffer, output_dim_);
+}
+
+void DeformableConvPlugin::destroy() TRT_NOEXCEPT {}
+
+void DeformableConvPlugin::setPluginNamespace(const char* lib_namespace)
+    TRT_NOEXCEPT {
+  namespace_ = std::string(lib_namespace);
+}
+
+const char* DeformableConvPlugin::getPluginNamespace() const TRT_NOEXCEPT {
+  return namespace_.c_str();
+}
+
+nvinfer1::DataType DeformableConvPlugin::getOutputDataType(
+    int index, const nvinfer1::DataType* input_type,
+    int nb_inputs) const TRT_NOEXCEPT {
+  return data_type_;
+}
+
+bool DeformableConvPlugin::isOutputBroadcastAcrossBatch(
+    int output_index, const bool* input_is_broadcast,
+    int nb_inputs) const TRT_NOEXCEPT {
+  return false;
+}
+
+bool DeformableConvPlugin::canBroadcastInputAcrossBatch(int input_index) const
+    TRT_NOEXCEPT {
+  return false;
+}
+
+void DeformableConvPlugin::attachToContext(
+    cudnnContext* cudnnContext, cublasContext* cublasContext,
+    nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {
+  cublasHandle_ = cublasContext;
+}
+
+void DeformableConvPlugin::configurePlugin(
+    const nvinfer1::Dims* input_dims, int nb_inputs,
+    const nvinfer1::Dims* output_dims, int nb_outputs,
+    const nvinfer1::DataType* input_types,
+    const nvinfer1::DataType* output_types, const bool* input_is_broadcast,
+    const bool* output_is_broadcast, nvinfer1::PluginFormat float_format,
+    int max_batch_size) TRT_NOEXCEPT {
+  PADDLE_ENFORCE_EQ(
+      nb_inputs, 3,
+      platform::errors::InvalidArgument(
+          "The number of inputs should be equal to 3, but got %d", nb_inputs));
+  PADDLE_ENFORCE_EQ(
+      nb_outputs, 1,
+      platform::errors::InvalidArgument(
+          "The number of outputs should be equal to 1, but got %d",
+          nb_outputs));
+
+  for (int i = 0; i < input_dims[0].nbDims; i++) {
+    input_dim_.push_back(input_dims[0].d[i]);
+  }
+  for (int i = 0; i < input_dims[1].nbDims; i++) {
+    offset_dim_.push_back(input_dims[1].d[i]);
+  }
+  for (int i = 0; i < input_dims[2].nbDims; i++) {
+    mask_dim_.push_back(input_dims[2].d[i]);
+  }
+  for (int i = 0; i < output_dims[0].nbDims; i++) {
+    output_dim_.push_back(output_dims[0].d[i]);
+  }
+}
+
+nvinfer1::IPluginV2Ext* DeformableConvPlugin::clone() const TRT_NOEXCEPT {
+  return new DeformableConvPlugin(data_type_, weights_, kernel_dims_,
+                                  strides_, paddings_, dilations_, groups_,
+                                  deformable_groups_, im2col_step_,
+                                  input_dim_, offset_dim_, mask_dim_,
+                                  output_dim_);
+}
+
+DeformableConvPluginCreator::DeformableConvPluginCreator() TRT_NOEXCEPT {}
+
+void DeformableConvPluginCreator::setPluginNamespace(const char* lib_namespace)
+    TRT_NOEXCEPT {
+  namespace_ = std::string(lib_namespace);
+}
+
+const char* DeformableConvPluginCreator::getPluginNamespace() const
+    TRT_NOEXCEPT {
+  return namespace_.c_str();
+}
+
+const char* DeformableConvPluginCreator::getPluginName() const TRT_NOEXCEPT {
+  return "deformable_conv_plugin";
+}
+
+const char* DeformableConvPluginCreator::getPluginVersion() const
+    TRT_NOEXCEPT {
+  return "1";
+}
+
+const nvinfer1::PluginFieldCollection*
+DeformableConvPluginCreator::getFieldNames() TRT_NOEXCEPT {
+  return &field_collection_;
+}
+
+nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin(
+    const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT {
+  const nvinfer1::PluginField* fields = fc->fields;
+
+  nvinfer1::DataType data_type;
+  std::vector<int> strides, paddings, dilations, kernel_dims;
+  nvinfer1::Weights weights;
+  int groups = -1;
+  int deformable_groups = -1;
+  int im2col_step = -1;
+
+  for (int i = 0; i < fc->nbFields; ++i) {
+    const std::string field_name(fc->fields[i].name);
+    // Note: std::string::compare returns 0 on equality, so every branch
+    // below must test `== 0`.
+    if (field_name.compare("data_type") == 0) {
+      data_type = *static_cast<const nvinfer1::DataType*>(fc->fields[i].data);
+    } else if (field_name.compare("strides") == 0) {
+      const int length = fc->fields[i].length;
+      const int* data = static_cast<const int*>(fc->fields[i].data);
+      strides.insert(strides.end(), data, data + length);
+    } else if (field_name.compare("paddings") == 0) {
+      const int length = fc->fields[i].length;
+      const int* data = static_cast<const int*>(fc->fields[i].data);
+      paddings.insert(paddings.end(), data, data + length);
+    } else if (field_name.compare("dilations") == 0) {
+      const int length = fc->fields[i].length;
+      const int* data = static_cast<const int*>(fc->fields[i].data);
+      dilations.insert(dilations.end(), data, data + length);
+    } else if (field_name.compare("groups") == 0) {
+      groups = *static_cast<const int*>(fc->fields[i].data);
+    } else if (field_name.compare("deformable_groups") == 0) {
+      deformable_groups = *static_cast<const int*>(fc->fields[i].data);
+    } else if (field_name.compare("im2col_step") == 0) {
+      im2col_step = *static_cast<const int*>(fc->fields[i].data);
+    } else if (field_name.compare("kernel_dims") == 0) {
+      const int length = fc->fields[i].length;
+      const int* data = static_cast<const int*>(fc->fields[i].data);
+      kernel_dims.insert(kernel_dims.end(), data, data + length);
+    } else if (field_name.compare("weights") == 0) {
+      weights.count = fc->fields[i].length;
+      weights.values = fc->fields[i].data;
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Unknown plugin field name [%s] in the DeformableConv TRT Plugin.",
+          field_name));
+    }
+  }
+  weights.type = data_type;
+  return new DeformableConvPlugin(data_type, weights, kernel_dims, strides,
+                                  paddings, dilations, groups,
+                                  deformable_groups, im2col_step);
+}
+
+nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::deserializePlugin(
+    const char* name, const void* serial_data,
+    size_t serial_length) TRT_NOEXCEPT {
+  auto plugin = new DeformableConvPlugin(serial_data, serial_length);
+  plugin->setPluginNamespace(namespace_.c_str());
+  return plugin;
+}
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
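
For reference, `createPlugin` above expects every attribute to arrive as a `PluginField` whose name matches one of the branches it parses. A hedged sketch of how a caller could drive it; the concrete shapes and values are illustrative only, not taken from this patch:

    // Illustrative only: build a PluginFieldCollection matching the field
    // names parsed by DeformableConvPluginCreator::createPlugin above.
    nvinfer1::DataType data_type = nvinfer1::DataType::kFLOAT;
    std::vector<int> strides{1, 1}, paddings{1, 1}, dilations{1, 1};
    std::vector<int> kernel_dims{6, 3, 3, 3};  // c_o, c_i, k_h, k_w
    std::vector<float> filter(6 * 3 * 3 * 3, 0.f);
    int groups = 1, deformable_groups = 1, im2col_step = 1;

    std::vector<nvinfer1::PluginField> fields = {
        {"data_type", &data_type, nvinfer1::PluginFieldType::kINT32, 1},
        {"strides", strides.data(), nvinfer1::PluginFieldType::kINT32, 2},
        {"paddings", paddings.data(), nvinfer1::PluginFieldType::kINT32, 2},
        {"dilations", dilations.data(), nvinfer1::PluginFieldType::kINT32, 2},
        {"groups", &groups, nvinfer1::PluginFieldType::kINT32, 1},
        {"deformable_groups", &deformable_groups,
         nvinfer1::PluginFieldType::kINT32, 1},
        {"im2col_step", &im2col_step, nvinfer1::PluginFieldType::kINT32, 1},
        {"kernel_dims", kernel_dims.data(), nvinfer1::PluginFieldType::kINT32,
         4},
        {"weights", filter.data(), nvinfer1::PluginFieldType::kFLOAT32,
         static_cast<int32_t>(filter.size())},
    };
    nvinfer1::PluginFieldCollection fc{static_cast<int32_t>(fields.size()),
                                       fields.data()};
    DeformableConvPluginCreator creator;
    auto* plugin = creator.createPlugin("deformable_conv_plugin", &fc);
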
diff --git a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
new file mode 100644
index 0000000000000000000000000000000000000000..9b04d6fb8ca2272ece0d300d6033197371a7632b
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h
@@ -0,0 +1,148 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/inference/tensorrt/engine.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+#include "paddle/fluid/platform/dynload/cublas.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+class DeformableConvPlugin : public nvinfer1::IPluginV2Ext {
+ public:
+  explicit DeformableConvPlugin(
+      const nvinfer1::DataType data_type, const nvinfer1::Weights& weights,
+      const std::vector<int>& kernel_dims, const std::vector<int>& strides,
+      const std::vector<int>& paddings, const std::vector<int>& dilations,
+      const int groups, const int deformable_groups, const int im2col_step);
+  explicit DeformableConvPlugin(
+      const nvinfer1::DataType data_type, const nvinfer1::Weights& weights,
+      const std::vector<int>& kernel_dims, const std::vector<int>& strides,
+      const std::vector<int>& paddings, const std::vector<int>& dilations,
+      const int groups, const int deformable_groups, const int im2col_step,
+      const std::vector<int>& input_dim, const std::vector<int>& offset_dim,
+      const std::vector<int>& mask_dim, const std::vector<int>& output_dim);
+  DeformableConvPlugin(const void* data, size_t length);
+  ~DeformableConvPlugin() override;
+
+  const char* getPluginType() const TRT_NOEXCEPT override;
+  const char* getPluginVersion() const TRT_NOEXCEPT override;
+  int getNbOutputs() const TRT_NOEXCEPT override;
+  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
+                                     int nb_input_dims) TRT_NOEXCEPT override;
+  bool supportsFormat(nvinfer1::DataType type, nvinfer1::TensorFormat format)
+      const TRT_NOEXCEPT override;
+  size_t getWorkspaceSize(int max_batch_size) const TRT_NOEXCEPT override;
+#if IS_TRT_VERSION_LT(8000)
+  int enqueue(int batch_size, const void* const* inputs, void** outputs,
+#else
+  int enqueue(int batch_size, const void* const* inputs, void* const* outputs,
+#endif
+              void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
+  int initialize() TRT_NOEXCEPT override;
+  void terminate() TRT_NOEXCEPT override;
+  size_t getSerializationSize() const TRT_NOEXCEPT override;
+  void serialize(void* buffer) const TRT_NOEXCEPT override;
+  void destroy() TRT_NOEXCEPT override;
+  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override;
+  const char* getPluginNamespace() const TRT_NOEXCEPT override;
+  nvinfer1::DataType getOutputDataType(
+      int index, const nvinfer1::DataType* input_type,
+      int nb_inputs) const TRT_NOEXCEPT override;
+  bool isOutputBroadcastAcrossBatch(int output_index,
+                                    const bool* input_is_broadcast,
+                                    int nb_inputs) const TRT_NOEXCEPT override;
+  bool canBroadcastInputAcrossBatch(int input_index) const
+      TRT_NOEXCEPT override;
+
+  void attachToContext(cudnnContext* cudnnContext,
+                       cublasContext* cublasContext,
+                       nvinfer1::IGpuAllocator* gpuAllocator)
+      TRT_NOEXCEPT override;
+
+  void configurePlugin(const nvinfer1::Dims* input_dims, int nb_inputs,
+                       const nvinfer1::Dims* output_dims, int nb_outputs,
+                       const nvinfer1::DataType* input_types,
+                       const nvinfer1::DataType* output_types,
+                       const bool* input_is_broadcast,
+                       const bool* output_is_broadcast,
+                       nvinfer1::PluginFormat float_format,
+                       int max_batch_size) TRT_NOEXCEPT override;
+  nvinfer1::IPluginV2Ext* clone() const TRT_NOEXCEPT override;
+
+ private:
+  template <typename T>
+  int enqueue_impl(int batch_size, const void* const* inputs, void** outputs,
+                   void* workspace, cudaStream_t stream);
+  nvinfer1::Weights copyToDevice(const void* hostData, size_t count);
+  void serializeFromDevice(void** hostBuffer,
+                           const nvinfer1::Weights& deviceWeights) const;
+  nvinfer1::Weights deserializeToDevice(const void** hostBuffer,
+                                        size_t count);
+
+  nvinfer1::DataType data_type_;
+  nvinfer1::Weights weights_;
+  std::vector<int> kernel_dims_;
+  std::vector<int> strides_;
+  std::vector<int> paddings_;
+  std::vector<int> dilations_;
+  int groups_;
+  int deformable_groups_;
+  int im2col_step_;
+  std::string namespace_;
+
+  std::vector<int> input_dim_;
+  std::vector<int> offset_dim_;
+  std::vector<int> mask_dim_;
+  std::vector<int> output_dim_;
+
+  cublasHandle_t cublasHandle_;
+};
+
+class DeformableConvPluginCreator : public nvinfer1::IPluginCreator {
+ public:
+  DeformableConvPluginCreator();
+  ~DeformableConvPluginCreator() override = default;
+
+  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override;
+  const char* getPluginNamespace() const TRT_NOEXCEPT override;
+  const char* getPluginName() const TRT_NOEXCEPT override;
+  const char* getPluginVersion() const TRT_NOEXCEPT override;
+  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
+
+  nvinfer1::IPluginV2Ext* createPlugin(
+      const char* name,
+      const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override;
+  nvinfer1::IPluginV2Ext* deserializePlugin(
+      const char* name, const void* serial_data,
+      size_t serial_length) TRT_NOEXCEPT override;
+
+ private:
+  std::string namespace_;
+  nvinfer1::PluginFieldCollection field_collection_;
+};
+
+REGISTER_TRT_PLUGIN_V2(DeformableConvPluginCreator);
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tests/infer_ut/test_ppyolov2_r50vd.cc b/paddle/fluid/inference/tests/infer_ut/test_ppyolov2_r50vd.cc
index 9689ec20956a1746c1f5a15d13722a3f68399da0..67b0c5ca17c2faa2a4c83dd934a246e606be165c 100644
--- a/paddle/fluid/inference/tests/infer_ut/test_ppyolov2_r50vd.cc
+++ b/paddle/fluid/inference/tests/infer_ut/test_ppyolov2_r50vd.cc
@@ -73,7 +73,7 @@ TEST(tensorrt_tester_ppyolov2_r50vd, multi_thread2_trt_fp32_bz1) {
                   FLAGS_modeldir + "/model.pdiparams");
   config.EnableUseGpu(100, 0);
   config.EnableTensorRtEngine(
-      1 << 20, 2, 10, paddle_infer::PrecisionType::kFloat32, false, false);
+      1 << 28, 2, 10, paddle_infer::PrecisionType::kFloat32, false, false);
   LOG(INFO) << config.Summary();
   // get groudtruth by disbale ir
   paddle_infer::services::PredictorPool pred_pool_no_ir(config_no_ir, 1);
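
The workspace bump in the test above matters because the plugin's im2col buffer is drawn from the TensorRT workspace (see `getWorkspaceSize` earlier in this patch). For context, a minimal standalone sketch of the configuration that exercises the new converter; the model paths are placeholders, not files from this patch:

    #include "paddle/fluid/inference/api/paddle_inference_api.h"

    // Hypothetical driver: model paths are placeholders.
    int main() {
      paddle_infer::Config config;
      config.SetModel("model.pdmodel", "model.pdiparams");
      config.EnableUseGpu(100, 0);
      // Same argument order as in the test above: workspace size, max batch,
      // min subgraph size, precision, use_static, use_calib_mode.
      config.EnableTensorRtEngine(1 << 28, 2, 10,
                                  paddle_infer::PrecisionType::kFloat32,
                                  false, false);
      auto predictor = paddle_infer::CreatePredictor(config);
      return 0;
    }
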
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_deformable_conv.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_deformable_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d29034d7fe18d069dd7bc3b4651a4c97d13976a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_deformable_conv.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
+from program_config import TensorConfig, ProgramConfig
+import numpy as np
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import Optional, List, Callable, Dict, Any, Set
+import unittest
+
+
+class TrtConvertDeformableConvTest(TrtLayerAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        inputs = program_config.inputs
+        weights = program_config.weights
+        attrs = [
+            program_config.ops[i].attrs
+            for i in range(len(program_config.ops))
+        ]
+
+        if inputs['input_data'].shape[1] != weights['filter_data'].shape[
+                1] * attrs[0]['groups']:
+            return False
+
+        return True
+
+    def sample_program_configs(self):
+        def compute_output_size(input_size: List[int],
+                                kernel_sizes: List[int],
+                                attrs: List[Dict[str, Any]]):
+            strides = attrs[0]['strides']
+            paddings = attrs[0]['paddings']
+            dilations = attrs[0]['dilations']
+            output_size = []
+            for i, k, s, p, d in zip(input_size, kernel_sizes, strides,
+                                     paddings, dilations):
+                k = d * (k - 1) + 1
+                output_size.append((i + 2 * p - k) // s + 1)
+            return output_size
+
+        def generate_input1(batch: int,
+                            input_size: List[int],
+                            kernel_sizes: List[int],
+                            attrs: List[Dict[str, Any]]):
+            return np.random.random([batch, 3] + input_size).astype(np.float32)
+
+        def generate_offset1(batch: int,
+                             input_size: List[int],
+                             kernel_sizes: List[int],
+                             attrs: List[Dict[str, Any]]):
+            output_size = compute_output_size(input_size, kernel_sizes, attrs)
+            return np.random.random([batch, 2 * np.prod(kernel_sizes)] +
+                                    output_size).astype(np.float32)
+
+        def generate_mask1(batch: int,
+                           input_size: List[int],
+                           kernel_sizes: List[int],
+                           attrs: List[Dict[str, Any]]):
+            output_size = compute_output_size(input_size, kernel_sizes, attrs)
+            return np.random.random([batch, np.prod(kernel_sizes)] +
+                                    output_size).astype(np.float32)
+
+        def generate_filter1(batch: int,
+                             input_size: List[int],
+                             kernel_sizes: List[int],
+                             attrs: List[Dict[str, Any]]):
+            return np.random.random([6, 3] + kernel_sizes).astype(np.float32)
+
+        for batch in [1, ]:
+            for input_size in [[32, 32]]:
+                for kernel_sizes in [[3, 3]]:
+                    for strides in [[1, 1], [2, 2]]:
+                        for paddings in [[1, 1], [0, 2]]:
+                            for groups in [1, ]:
+                                for dilations in [[1, 1], [2, 2]]:
+                                    dics = [{
+                                        "strides": strides,
+                                        "paddings": paddings,
+                                        "groups": groups,
+                                        "dilations": dilations,
+                                        "deformable_groups": 1,
+                                        "im2col_step": 1
+                                    }]
+
+                                    ops_config = [{
+                                        "op_type": "deformable_conv",
+                                        "op_inputs": {
+                                            "Input": ["input_data"],
+                                            "Offset": ["offset_data"],
+                                            "Mask": ["mask_data"],
+                                            "Filter": ["filter_data"]
+                                        },
+                                        "op_outputs": {
+                                            "Output": ["output_data"]
+                                        },
+                                        "op_attrs": dics[0]
+                                    }]
+                                    ops = self.generate_op_config(ops_config)
+
+                                    program_config = ProgramConfig(
+                                        ops=ops,
+                                        weights={
+                                            "filter_data": TensorConfig(
+                                                data_gen=partial(
+                                                    generate_filter1, batch,
+                                                    input_size, kernel_sizes,
+                                                    dics))
+                                        },
+                                        inputs={
+                                            "input_data": TensorConfig(
+                                                data_gen=partial(
+                                                    generate_input1, batch,
+                                                    input_size, kernel_sizes,
+                                                    dics)),
+                                            "offset_data": TensorConfig(
+                                                data_gen=partial(
+                                                    generate_offset1, batch,
+                                                    input_size, kernel_sizes,
+                                                    dics)),
+                                            "mask_data": TensorConfig(
+                                                data_gen=partial(
+                                                    generate_mask1, batch,
+                                                    input_size, kernel_sizes,
+                                                    dics))
+                                        },
+                                        outputs=["output_data"])
+
+                                    yield program_config
+
+    def sample_predictor_configs(
+            self, program_config) -> (paddle_infer.Config, List[int], float):
+        def clear_dynamic_shape():
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):
+            # TODO: This is just an example and needs to be fixed.
+            if len(attrs[0]['paddings']) == 4:
+                return 1, 2
+            else:
+                return 1, 2
+
+        attrs = [
+            program_config.ops[i].attrs
+            for i in range(len(program_config.ops))
+        ]
+
+        # for static_shape
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False), 1e-5
+
+    def add_skip_trt_case(self):
+        def teller1(program_config, predictor_config):
+            if len(program_config.ops[0].attrs["strides"]) != 2:
+                return False
+
+            return True
+
+        self.add_skip_case(
+            teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
+            "In deformable conv, length of Attr(strides) should be 2.")
+
+    def test(self):
+        self.trt_param.workspace_size = 1 << 28
+        self.add_skip_trt_case()
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_deformable_conv.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_deformable_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..508095fb801757ead30400527b781ecf925d0046
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_deformable_conv.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from inference_pass_test import InferencePassTest
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.fluid.core import PassVersionChecker
+from paddle.fluid.core import AnalysisConfig
+
+
+class TRTDeformableConvTest(InferencePassTest):
+    def setUp(self):
+        self.set_params()
+        with fluid.program_guard(self.main_program, self.startup_program):
+            input = fluid.data(
+                name='input', shape=self.input_size, dtype=self.dtype)
+            offset = fluid.data(
+                name='offset', shape=self.offset_size, dtype=self.dtype)
+            mask = fluid.data(
+                name='mask', shape=self.mask_size, dtype=self.dtype)
+
+            output = fluid.layers.deformable_conv(
+                input,
+                offset,
+                mask,
+                self.num_filters,
+                self.filter_size,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilations,
+                groups=self.groups,
+                deformable_groups=self.deformable_groups,
+                im2col_step=self.im2col_step)
+
+        self.feeds = {
+            'input': np.random.random(self.input_size).astype(self.dtype),
+            'offset': np.random.random(self.offset_size).astype(self.dtype),
+            'mask': np.random.random(self.mask_size).astype(self.dtype)
+        }
+        self.enable_trt = True
+        dtype = AnalysisConfig.Precision.Float32
+        if self.dtype == 'float16':
+            dtype = AnalysisConfig.Precision.Half
+        self.trt_parameters = TRTDeformableConvTest.TensorRTParam(
+            1 << 30, self.bs, 0, dtype, False, False)
+        self.fetch_list = [output]
+
+    def set_params(self):
+        self.groups = 1
+        self.padding = [1, 1]
+        self.dilations = [1, 1]
+        self.stride = [1, 1]
+        self.im2col_step = 1
+        self.deformable_groups = 1
+
+        self.bs = 2
+        self.input_size = [self.bs, 8, 4, 4]
+        self.num_filters = 8
+        self.filter_size = 3
+        offset_c = 2 * self.deformable_groups * self.filter_size * self.filter_size
+        mask_c = self.deformable_groups * self.filter_size * self.filter_size
+        self.offset_size = [
+            self.input_size[0], offset_c, self.input_size[2],
+            self.input_size[3]
+        ]
+        self.mask_size = [
+            self.input_size[0], mask_c, self.input_size[2],
+            self.input_size[3]
+        ]
+
+        self.dtype = 'float32'
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            use_gpu = True
+            self.check_output_with_option(use_gpu)
+            self.assertTrue(
+                PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
+
+
+if __name__ == "__main__":
+    unittest.main()