未验证 提交 8c3decd8 编写于 作者: W wangxinxin08 提交者: GitHub

add dcnv2 trt plugin (#36612)

* add dcnv2 plugin
上级 d6b1beb0
......@@ -1415,6 +1415,7 @@ USE_TRT_CONVERTER(tile);
USE_TRT_CONVERTER(conv3d);
USE_TRT_CONVERTER(conv3d_transpose);
USE_TRT_CONVERTER(mish);
USE_TRT_CONVERTER(deformable_conv);
USE_TRT_CONVERTER(pool3d)
#endif
......
......@@ -20,6 +20,7 @@ nv_library(tensorrt_converter
mish_op.cc
nearest_interp_v2_op.cc
pool3d_op.cc
deformable_conv_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstdio>
#include <vector>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
class DeformableConvOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a deformable conv op to tensorrt plugin";
framework::OpDesc op_desc(op, nullptr);
std::string input_name = op_desc.Input("Input").front();
std::string offset_name = op_desc.Input("Offset").front();
std::string mask_name = op_desc.Input("Mask").front();
std::string filter_name = op_desc.Input("Filter").front();
auto* input_tensor = engine_->GetITensor(input_name);
auto* offset_tensor = engine_->GetITensor(offset_name);
auto* mask_tensor = engine_->GetITensor(mask_name);
auto* filter_var = scope.FindVar(filter_name);
auto* filter_tensor = filter_var->GetMutable<framework::LoDTensor>();
float* filter_data =
engine_->GetWeightCPUData(filter_name, filter_tensor, false);
const int c_o = filter_tensor->dims()[0];
const int c_i = filter_tensor->dims()[1];
const int k_h = filter_tensor->dims()[2];
const int k_w = filter_tensor->dims()[3];
std::vector<int> kernel_dims = {c_o, c_i, k_h, k_w};
auto strides =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
auto paddings =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));
auto dilations =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("dilations"));
auto groups = BOOST_GET_CONST(int, op_desc.GetAttr("groups"));
auto deformable_groups =
BOOST_GET_CONST(int, op_desc.GetAttr("deformable_groups"));
auto im2col_step = BOOST_GET_CONST(int, op_desc.GetAttr("im2col_step"));
nvinfer1::Weights weights;
weights.count = filter_tensor->numel();
if (engine_->WithFp16()) {
auto half_filter_data = new half[filter_tensor->numel()];
for (int i = 0; i < filter_tensor->numel(); i++) {
half_filter_data[i] = static_cast<half>(filter_data[i]);
}
weights.type = nvinfer1::DataType::kHALF;
weights.values = half_filter_data;
} else {
weights.type = nvinfer1::DataType::kFLOAT;
weights.values = filter_data;
}
auto* deformable_conv_plugin = new plugin::DeformableConvPlugin(
engine_->WithFp16() ? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT,
weights, kernel_dims, strides, paddings, dilations, groups,
deformable_groups, im2col_step);
std::vector<nvinfer1::ITensor*> deformable_conv_inputs;
deformable_conv_inputs.push_back(input_tensor);
deformable_conv_inputs.push_back(offset_tensor);
deformable_conv_inputs.push_back(mask_tensor);
auto* deformable_conv_layer = engine_->network()->addPluginV2(
deformable_conv_inputs.data(), deformable_conv_inputs.size(),
*deformable_conv_plugin);
std::vector<std::string> output_names;
output_names.push_back(op_desc.Output("Output").front());
RreplenishLayerAndOutput(deformable_conv_layer, "deformable_conv",
output_names, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(deformable_conv, DeformableConvOpConverter);
......@@ -143,7 +143,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"conv3d_transpose",
"mish",
"nearest_interp_v2",
"pool3d"};
"pool3d",
"deformable_conv"};
};
bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
......@@ -332,6 +333,51 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
#endif
}
if (op_type == "deformable_conv") {
if (with_dynamic_shape) {
VLOG(3) << "Deformable conv trt plugin does not support dynamic shape";
return false;
}
auto* block = desc.Block();
auto input_name = desc.Input("Input")[0];
auto* input_desc = block->FindVar(input_name);
const auto input_shape = input_desc->GetShape();
if (input_shape.size() != 4) {
VLOG(3) << "Input of deformable conv should be 4-D Tensor, but got "
<< input_shape.size();
return false;
}
auto filter_name = desc.Input("Filter")[0];
auto* filter_desc = block->FindVar(filter_name);
const auto filter_shape = filter_desc->GetShape();
int groups = BOOST_GET_CONST(int, desc.GetAttr("groups"));
if (input_shape[1] != filter_shape[1] * groups) {
VLOG(3) << "The number of input channels should be equal to filter "
<< "channels * groups. But got input channels "
<< input_shape[1] << "filter channels " << filter_shape[1];
return false;
}
const std::vector<int> strides =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("strides"));
if (strides.size() != 2) {
VLOG(3) << "The size of strides should be 2, but got "
<< strides.size();
return false;
}
const std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
if (paddings.size() != 2) {
VLOG(3) << "The size of paddings shoule be 2, but got "
<< paddings.size();
return false;
}
}
if (op_type == "matmul") {
auto* block = desc.Block();
if (block == nullptr) {
......
......@@ -11,6 +11,7 @@ nv_library(tensorrt_plugin
gather_nd_op_plugin.cu
mish_op_plugin.cu
pool3d_op_plugin.cu
deformable_conv_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdio>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class DeformableConvPlugin : public nvinfer1::IPluginV2Ext {
public:
explicit DeformableConvPlugin(
const nvinfer1::DataType data_type, const nvinfer1::Weights& weights,
const std::vector<int>& kernel_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations,
const int groups, const int deformable_groups, const int im2col_step);
explicit DeformableConvPlugin(
const nvinfer1::DataType data_type, const nvinfer1::Weights& weights,
const std::vector<int>& kernel_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations,
const int groups, const int deformable_groups, const int im2col_step,
const std::vector<int>& input_dim, const std::vector<int>& offset_dim,
const std::vector<int>& mask_dim, const std::vector<int>& output_dim);
DeformableConvPlugin(const void* data, size_t length);
~DeformableConvPlugin() override;
const char* getPluginType() const TRT_NOEXCEPT override;
const char* getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
int nb_input_dims) TRT_NOEXCEPT override;
bool supportsFormat(nvinfer1::DataType type, nvinfer1::TensorFormat format)
const TRT_NOEXCEPT override;
size_t getWorkspaceSize(int max_batch_size) const TRT_NOEXCEPT override;
#if IS_TRT_VERSION_LT(8000)
int enqueue(int batch_size, const void* const* inputs, void** outputs,
#else
int enqueue(int batch_size, const void* const* inputs, void* const* outputs,
#endif
void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
int initialize() TRT_NOEXCEPT override;
void terminate() TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override;
void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override;
const char* getPluginNamespace() const TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(
int index, const nvinfer1::DataType* input_type,
int nb_inputs) const TRT_NOEXCEPT override;
bool isOutputBroadcastAcrossBatch(int output_index,
const bool* input_is_broadcast,
int nb_inputs) const TRT_NOEXCEPT override;
bool canBroadcastInputAcrossBatch(int input_index) const
TRT_NOEXCEPT override;
void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext,
nvinfer1::IGpuAllocator* gpuAllocator)
TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::Dims* input_dims, int nb_inputs,
const nvinfer1::Dims* output_dims, int nb_outputs,
const nvinfer1::DataType* input_types,
const nvinfer1::DataType* output_types,
const bool* input_is_broadcast,
const bool* output_is_broadcast,
nvinfer1::PluginFormat float_format,
int max_batct_size) TRT_NOEXCEPT override;
nvinfer1::IPluginV2Ext* clone() const TRT_NOEXCEPT override;
private:
template <typename T>
int enqueue_impl(int batch_size, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream);
nvinfer1::Weights copyToDevice(const void* hostData, size_t count);
void serializeFromDevice(void** hostBuffer,
const nvinfer1::Weights& deviceWeights) const;
nvinfer1::Weights deserializeToDevice(const void** hostBuffer, size_t count);
nvinfer1::DataType data_type_;
nvinfer1::Weights weights_;
std::vector<int> kernel_dims_;
std::vector<int> strides_;
std::vector<int> paddings_;
std::vector<int> dilations_;
int groups_;
int deformable_groups_;
int im2col_step_;
std::string namespace_;
std::vector<int> input_dim_;
std::vector<int> offset_dim_;
std::vector<int> mask_dim_;
std::vector<int> output_dim_;
cublasHandle_t cublasHandle_;
};
class DeformableConvPluginCreator : public nvinfer1::IPluginCreator {
public:
DeformableConvPluginCreator();
~DeformableConvPluginCreator() override = default;
void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override;
const char* getPluginNamespace() const TRT_NOEXCEPT override;
const char* getPluginName() const TRT_NOEXCEPT override;
const char* getPluginVersion() const TRT_NOEXCEPT override;
const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
nvinfer1::IPluginV2Ext* createPlugin(
const char* name,
const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override;
nvinfer1::IPluginV2Ext* deserializePlugin(
const char* name, const void* serial_data,
size_t serial_length) TRT_NOEXCEPT override;
private:
std::string namespace_;
nvinfer1::PluginFieldCollection field_collection_;
};
REGISTER_TRT_PLUGIN_V2(DeformableConvPluginCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -73,7 +73,7 @@ TEST(tensorrt_tester_ppyolov2_r50vd, multi_thread2_trt_fp32_bz1) {
FLAGS_modeldir + "/model.pdiparams");
config.EnableUseGpu(100, 0);
config.EnableTensorRtEngine(
1 << 20, 2, 10, paddle_infer::PrecisionType::kFloat32, false, false);
1 << 28, 2, 10, paddle_infer::PrecisionType::kFloat32, false, false);
LOG(INFO) << config.Summary();
// get groudtruth by disbale ir
paddle_infer::services::PredictorPool pred_pool_no_ir(config_no_ir, 1);
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
class TrtConvertDeformableConvTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
inputs = program_config.inputs
weights = program_config.weights
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
if inputs['input_data'].shape[1] != weights['filter_data'].shape[
1] * attrs[0]['groups']:
return False
return True
def sample_program_configs(self):
def compute_output_size(input_size: List[int],
kernel_sizes: List[int],
attrs: List[Dict[str, Any]]):
strides = attrs[0]['strides']
paddings = attrs[0]['paddings']
dilations = attrs[0]['dilations']
output_size = []
for i, k, s, p, d in zip(input_size, kernel_sizes, strides,
paddings, dilations):
k = d * (k - 1) + 1
output_size.append((i + 2 * p - k) // s + 1)
return output_size
def generate_input1(batch: int,
input_size: List[int],
kernel_sizes: List[int],
attrs: List[Dict[str, Any]]):
return np.random.random([batch, 3] + input_size).astype(np.float32)
def generate_offset1(batch: int,
input_size: List[int],
kernel_sizes: List[int],
attrs: List[Dict[str, Any]]):
output_size = compute_output_size(input_size, kernel_sizes, attrs)
return np.random.random([batch, 2 * np.prod(kernel_sizes)] +
output_size).astype(np.float32)
def generate_mask1(batch: int,
input_size: List[int],
kernel_sizes: List[int],
attrs: List[Dict[str, Any]]):
output_size = compute_output_size(input_size, kernel_sizes, attrs)
return np.random.random([batch, np.prod(kernel_sizes)] +
output_size).astype(np.float32)
def generate_filter1(batch: int,
input_size: List[int],
kernel_sizes: List[int],
attrs: List[Dict[str, Any]]):
return np.random.random([6, 3] + kernel_sizes).astype(np.float32)
for batch in [1, ]:
for input_size in [[32, 32]]:
for kernel_sizes in [[3, 3]]:
for strides in [[1, 1], [2, 2]]:
for paddings in [[1, 1], [0, 2]]:
for groups in [1, ]:
for dilations in [[1, 1], [2, 2]]:
dics = [{
"strides": strides,
"paddings": paddings,
"groups": groups,
"dilations": dilations,
"deformable_groups": 1,
"im2col_step": 1
}]
ops_config = [{
"op_type": "deformable_conv",
"op_inputs": {
"Input": ["input_data"],
"Offset": ["offset_data"],
"Mask": ["mask_data"],
"Filter": ["filter_data"]
},
"op_outputs": {
"Output": ["output_data"]
},
"op_attrs": dics[0]
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"filter_data":
TensorConfig(data_gen=partial(
generate_filter1, batch, input_size,
kernel_sizes, dics))
},
inputs={
"input_data":
TensorConfig(data_gen=partial(
generate_input1, batch, input_size,
kernel_sizes, dics)),
"offset_data":
TensorConfig(data_gen=partial(
generate_offset1, batch, input_size,
kernel_sizes, dics)),
"mask_data": TensorConfig(
data_gen=partial(
generate_mask1, batch,
input_size, kernel_sizes, dics))
},
outputs=["output_data"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
# TODO: This is just the example, need to be fixed.
if len(attrs[0]['paddings']) == 4:
return 1, 2
else:
return 1, 2
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if len(program_config.ops[0].attrs["strides"]) != 2:
return False
return True
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"In deformable conv, length of Attr(strides) should be 2.")
def test(self):
self.trt_param.workspace_size = 1 << 28
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig
class TRTDeformableConvTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
input = fluid.data(
name='input', shape=self.input_size, dtype=self.dtype)
offset = fluid.data(
name='offset', shape=self.offset_size, dtype=self.dtype)
mask = fluid.data(
name='mask', shape=self.mask_size, dtype=self.dtype)
output = fluid.layers.deformable_conv(
input,
offset,
mask,
self.num_filters,
self.filter_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilations,
groups=self.groups,
deformable_groups=self.deformable_groups,
im2col_step=self.im2col_step)
self.feeds = {
'input': np.random.random(self.input_size).astype(self.dtype),
'offset': np.random.random(self.offset_size).astype(self.dtype),
'mask': np.random.random(self.mask_size).astype(self.dtype)
}
self.enable_trt = True
dtype = AnalysisConfig.Precision.Float32
if self.dtype == 'float16':
dtype = AnalysisConfig.Precision.Half
self.trt_parameters = TRTDeformableConvTest.TensorRTParam(
1 << 30, self.bs, 0, dtype, False, False)
self.fetch_list = [output]
def set_params(self):
self.groups = 1
self.padding = [1, 1]
self.dilations = [1, 1]
self.stride = [1, 1]
self.im2col_step = 1
self.deformable_groups = 1
self.bs = 2
self.input_size = [self.bs, 8, 4, 4]
self.num_filters = 8
self.filter_size = 3
offset_c = 2 * self.deformable_groups * self.filter_size * self.filter_size
mask_c = self.deformable_groups * self.filter_size * self.filter_size
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
self.mask_size = [
self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
]
self.dtype = 'float32'
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册