未验证 提交 fdbdef0e 编写于 作者: P Pei Yang 提交者: GitHub

fix conv2d_transpose trt bugs (#33242)

上级 29dc439a
...@@ -103,11 +103,18 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, ...@@ -103,11 +103,18 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data), bias_size}; static_cast<void*>(bias_data), bias_size};
auto* layer = fadd_layer(const_cast<nvinfer1::ITensor*>(X), n_output, n_input, // In conv2d_transpose and depthwise_conv2d_transpose,
nv_ksize, weight, bias); // output channels = filter_dims[1] * groups
PADDLE_ENFORCE_NOT_NULL(layer, auto* layer = (op_desc.Type() == "conv2d_transpose" ||
platform::errors::Fatal("TensorRT create conv2d" op_desc.Type() == "depthwise_conv2d_transpose")
" layer error.")); ? fadd_layer(const_cast<nvinfer1::ITensor*>(X),
n_input * groups, nv_ksize, weight, bias)
: fadd_layer(const_cast<nvinfer1::ITensor*>(X), n_output,
nv_ksize, weight, bias);
PADDLE_ENFORCE_NOT_NULL(
layer, platform::errors::Fatal("TensorRT create conv2d/conv2d_transpose"
" layer failed."));
layer->setStride(nv_strides); layer->setStride(nv_strides);
layer->setPadding(nv_paddings); layer->setPadding(nv_paddings);
layer->setNbGroups(groups); layer->setNbGroups(groups);
...@@ -134,7 +141,6 @@ class Conv2dOpConverter : public OpConverter { ...@@ -134,7 +141,6 @@ class Conv2dOpConverter : public OpConverter {
ConvertConv2d( ConvertConv2d(
engine_, op, scope, test_mode, engine_, op, scope, test_mode,
[&](nvinfer1::ITensor* inputs, int n_output, /* Conv output maps */ [&](nvinfer1::ITensor* inputs, int n_output, /* Conv output maps */
int n_input, /* Conv input maps */
nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight, nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) -> nvinfer1::IConvolutionLayer* { TensorRTEngine::Weight& bias) -> nvinfer1::IConvolutionLayer* {
auto* layer = auto* layer =
...@@ -156,7 +162,6 @@ class Deconv2dOpConverter : public OpConverter { ...@@ -156,7 +162,6 @@ class Deconv2dOpConverter : public OpConverter {
ConvertConv2d( ConvertConv2d(
engine_, op, scope, test_mode, engine_, op, scope, test_mode,
[&](nvinfer1::ITensor* inputs, int n_output, /* Deconv input maps */ [&](nvinfer1::ITensor* inputs, int n_output, /* Deconv input maps */
int n_input, /* Deconv output maps */
nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight, nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) -> nvinfer1::IDeconvolutionLayer* { TensorRTEngine::Weight& bias) -> nvinfer1::IDeconvolutionLayer* {
auto* layer = auto* layer =
......
...@@ -36,6 +36,7 @@ class TensorRTSubgraphPassConvTest(InferencePassTest): ...@@ -36,6 +36,7 @@ class TensorRTSubgraphPassConvTest(InferencePassTest):
groups=self.conv_groups, groups=self.conv_groups,
padding=self.conv_padding, padding=self.conv_padding,
bias_attr=False, bias_attr=False,
use_cudnn=self.use_cudnn,
act=None) act=None)
self.feeds = { self.feeds = {
"data": np.random.random([1, 6, 64, 64]).astype("float32"), "data": np.random.random([1, 6, 64, 64]).astype("float32"),
...@@ -50,6 +51,7 @@ class TensorRTSubgraphPassConvTest(InferencePassTest): ...@@ -50,6 +51,7 @@ class TensorRTSubgraphPassConvTest(InferencePassTest):
self.conv_filter_size = 6 self.conv_filter_size = 6
self.conv_groups = 3 self.conv_groups = 3
self.conv_padding = [1, 1] self.conv_padding = [1, 1]
self.use_cudnn = True
def test_check_output(self): def test_check_output(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
...@@ -65,6 +67,7 @@ class TensorRTSubgraphPassConvValidPaddingTest(TensorRTSubgraphPassConvTest): ...@@ -65,6 +67,7 @@ class TensorRTSubgraphPassConvValidPaddingTest(TensorRTSubgraphPassConvTest):
self.conv_filter_size = 6 self.conv_filter_size = 6
self.conv_groups = 3 self.conv_groups = 3
self.conv_padding = 'VALID' self.conv_padding = 'VALID'
self.use_cudnn = True
class TensorRTSubgraphPassConvSamePaddingTest(InferencePassTest): class TensorRTSubgraphPassConvSamePaddingTest(InferencePassTest):
...@@ -73,6 +76,7 @@ class TensorRTSubgraphPassConvSamePaddingTest(InferencePassTest): ...@@ -73,6 +76,7 @@ class TensorRTSubgraphPassConvSamePaddingTest(InferencePassTest):
self.conv_filter_size = 6 self.conv_filter_size = 6
self.conv_groups = 3 self.conv_groups = 3
self.conv_padding = 'SAME' self.conv_padding = 'SAME'
self.use_cudnn = True
class TensorRTSubgraphPassDepthwiseConvTest(TensorRTSubgraphPassConvTest): class TensorRTSubgraphPassDepthwiseConvTest(TensorRTSubgraphPassConvTest):
...@@ -81,6 +85,16 @@ class TensorRTSubgraphPassDepthwiseConvTest(TensorRTSubgraphPassConvTest): ...@@ -81,6 +85,16 @@ class TensorRTSubgraphPassDepthwiseConvTest(TensorRTSubgraphPassConvTest):
self.conv_filter_size = 6 self.conv_filter_size = 6
self.conv_groups = 6 self.conv_groups = 6
self.conv_padding = [1, 1] self.conv_padding = [1, 1]
self.use_cudnn = False
class TensorRTSubgraphPassDepthwiseConv2Test(TensorRTSubgraphPassConvTest):
def set_params(self):
self.conv_num_filters = 12
self.conv_filter_size = 6
self.conv_groups = 6
self.conv_padding = [1, 1]
self.use_cudnn = False
class TensorRTSubgraphPassConvTransposeTest(InferencePassTest): class TensorRTSubgraphPassConvTransposeTest(InferencePassTest):
...@@ -151,6 +165,16 @@ class TensorRTSubgraphPassConvTransposeMultiGroupTest( ...@@ -151,6 +165,16 @@ class TensorRTSubgraphPassConvTransposeMultiGroupTest(
self.use_cudnn = True self.use_cudnn = True
class TensorRTSubgraphPassConvTranspose2Test(
TensorRTSubgraphPassConvTransposeTest):
def set_params(self):
self.conv_num_filters = 12
self.conv_filter_size = 4
self.conv_groups = 6
self.conv_padding = [1, 1]
self.use_cudnn = False
class TensorRTSubgraphPassDepthwiseConvTransposeTest( class TensorRTSubgraphPassDepthwiseConvTransposeTest(
TensorRTSubgraphPassConvTransposeTest): TensorRTSubgraphPassConvTransposeTest):
def set_params(self): def set_params(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册