Unverified · Commit 42847d2e authored by W wenbin, committed by GitHub

conv3d (#35507)

* conv3d

* remove const_cast

* modify ut

* disable dynamic shape for trt6.0

* remove trt5
Parent 512329b0
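For orientation, a minimal sketch of exercising the new conv3d path through the paddle.inference Python API. The model files and the input name "data" below are hypothetical placeholders, and the dynamic-shape call only applies on TRT >= 7.0 (see the op_teller change in this diff):

import numpy as np
import paddle.inference as paddle_infer

# Hypothetical saved model; any model containing a conv3d op will do.
config = paddle_infer.Config("model_dir/model.pdmodel",
                             "model_dir/model.pdiparams")
config.enable_use_gpu(1000, 0)
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=1,
    precision_mode=paddle_infer.PrecisionType.Float32,
    use_static=False,
    use_calib_mode=False)
# conv3d with dynamic shape requires TRT >= 7.0; shapes are NCDHW.
config.set_trt_dynamic_shape_info(
    {"data": [1, 3, 4, 16, 16]},   # min
    {"data": [4, 3, 8, 64, 64]},   # max
    {"data": [1, 3, 6, 32, 32]})   # opt
predictor = paddle_infer.create_predictor(config)
handle = predictor.get_input_handle(predictor.get_input_names()[0])
handle.copy_from_cpu(np.random.random([1, 3, 6, 32, 32]).astype("float32"))
predictor.run()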
......@@ -1257,6 +1257,8 @@ USE_TRT_CONVERTER(reduce_sum);
USE_TRT_CONVERTER(gather_nd);
USE_TRT_CONVERTER(reduce_mean);
USE_TRT_CONVERTER(tile);
USE_TRT_CONVERTER(conv3d);
USE_TRT_CONVERTER(conv3d_transpose);
#endif
namespace paddle_infer {
......
......@@ -16,6 +16,7 @@ nv_library(tensorrt_converter
reduce_op.cc
gather_nd_op.cc
tile_op.cc
conv3d_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
template <typename RegistFunc, typename SetDilationFunc>
void ConvertConv3d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode,
RegistFunc fadd_layer, SetDilationFunc fset_dilation,
const std::string& name) {
VLOG(3) << "convert a fluid " << name << " op to tensorrt layer without bias";
framework::OpDesc op_desc(op, nullptr);
auto* X = engine->GetITensor(op_desc.Input("Input").front());
std::string filter_var_name = op_desc.Input("Filter").front();
auto* Y_v = scope.FindVar(filter_var_name);
PADDLE_ENFORCE_NOT_NULL(
Y_v, platform::errors::NotFound(
"Cannot find the persistable var %s in scope.", filter_var_name));
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
float* weight_data = nullptr;
bool enable_int8 = op_desc.HasAttr("enable_int8");
if (enable_int8) {
float in_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
auto weight_scale =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t,
true, weight_scale);
engine->SetTensorDynamicRange(X, in_scale);
} else {
weight_data =
engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, false);
}
PADDLE_ENFORCE_EQ(Y_t->dims().size(), 5UL,
platform::errors::InvalidArgument(
"The conv3d filter's dims size should be 5, but got %d",
Y_t->dims().size()));
const int n_output = Y_t->dims()[0];
const int n_input = Y_t->dims()[1];
const int filter_d = Y_t->dims()[2];
const int filter_h = Y_t->dims()[3];
const int filter_w = Y_t->dims()[4];
const int groups = BOOST_GET_CONST(int, op_desc.GetAttr("groups"));
const std::vector<int> dilations =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("dilations"));
const std::vector<int> strides =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
const std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));
std::string padding_algorithm = "EXPLICIT";
if (op_desc.HasAttr("padding_algorithm"))
padding_algorithm =
BOOST_GET_CONST(std::string, op_desc.GetAttr("padding_algorithm"));
nvinfer1::Dims3 nv_ksize(filter_d, filter_h, filter_w);
nvinfer1::Dims3 nv_dilations(dilations[0], dilations[1], dilations[2]);
nvinfer1::Dims3 nv_strides(strides[0], strides[1], strides[2]);
nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], paddings[2]);
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<size_t>(Y_t->numel())};
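// Bias is not fused by this converter (see the VLOG above: "without
// bias"), so an empty weight struct is handed to TensorRT.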
float* bias_data = nullptr;
size_t bias_size = 0;
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data), bias_size};
// In conv3d_transpose the filter layout is [C_in, C_out / groups, kD, kH,
// kW], so output channels = filter_dims[1] * groups
auto* layer = (op_desc.Type() == "conv3d_transpose")
? fadd_layer(X, n_input * groups, nv_ksize, weight, bias)
: fadd_layer(X, n_output, nv_ksize, weight, bias);
PADDLE_ENFORCE_NOT_NULL(
layer, platform::errors::Fatal("TensorRT failed to create the conv3d/"
"conv3d_transpose layer."));
layer->setStrideNd(nv_strides);
layer->setPaddingNd(nv_paddings);
layer->setNbGroups(groups);
if (padding_algorithm == "SAME") {
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
}
// set dilations
fset_dilation(layer, nv_dilations);
auto output_name = op_desc.Output("Output").front();
layer->setName((name + " (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
engine->SetITensor(output_name, layer->getOutput(0));
if (test_mode) {
engine->DeclareOutput(output_name);
}
}
class Conv3dOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
ConvertConv3d(
engine_, op, scope, test_mode,
[&](nvinfer1::ITensor* inputs, int n_output, /* Conv output maps */
nvinfer1::Dims& ksize, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) -> nvinfer1::IConvolutionLayer* {
auto* layer =
TRT_ENGINE_ADD_LAYER(engine_, ConvolutionNd, *inputs, n_output,
ksize, weight.get(), bias.get());
return layer;
},
[](nvinfer1::IConvolutionLayer* layer, nvinfer1::Dims& dilations) {
layer->setDilationNd(dilations);
},
"conv3d");
}
};
class Deconv3dOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
ConvertConv3d(
engine_, op, scope, test_mode,
[&](nvinfer1::ITensor* inputs, int n_output, /* Deconv input maps */
nvinfer1::Dims& ksize, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) -> nvinfer1::IDeconvolutionLayer* {
auto* layer =
TRT_ENGINE_ADD_LAYER(engine_, DeconvolutionNd, *inputs, n_output,
ksize, weight.get(), bias.get());
return layer;
},
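// No dilation is set for deconvolution: the op teller already rejects
// conv3d_transpose with dilations other than (1, 1, 1).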
[](nvinfer1::IDeconvolutionLayer* layer, nvinfer1::Dims& dilations) {},
"conv3d_transpose");
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(conv3d, Conv3dOpConverter);
REGISTER_TRT_OP_CONVERTER(conv3d_transpose, Deconv3dOpConverter);
......@@ -76,11 +76,7 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
"TensorRT's tensor input requires at least 1 "
"dimensions, but input %s has %d dims.",
input, shape.size()));
PADDLE_ENFORCE_LE(shape.size(), 4UL,
platform::errors::InvalidArgument(
"TensorRT's tensor input requires at most 4 "
"dimensions, but input %s has %d dims.",
input, shape.size()));
auto ShapeStr = [](const std::vector<T>& shape) {
std::ostringstream os;
os << "[";
......@@ -103,6 +99,14 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
input, ShapeStr(shape)));
}
return nvinfer1::Dims3(shape[1], shape[2], shape[3]);
} else if (shape.size() == 5UL) {
if (shape[2] == -1 || shape[3] == -1 || shape[4] == -1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input [%s] shape of trt subgraph is %s, please enable "
"trt dynamic_shape mode by SetTRTDynamicShapeInfo.",
input, ShapeStr(shape)));
}
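// In static-shape mode TensorRT omits the batch dimension, so an NCDHW
// input is described as a 4-D CDHW tensor.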
return nvinfer1::Dims4(shape[1], shape[2], shape[3], shape[4]);
} else if (shape.size() == 3UL) {
if (shape[1] == -1 || shape[2] == -1) {
PADDLE_THROW(platform::errors::InvalidArgument(
......
......@@ -90,51 +90,51 @@ struct SimpleOpTypeSetTeller : public Teller {
"elementwise_mul",
"conv2d_transpose",
"hard_swish"};
std::unordered_set<std::string> teller_set{
"mul",
"matmul",
"conv2d",
"conv2d_fusion",
"pool2d",
"relu",
"softmax",
"sigmoid",
"hard_swish",
"depthwise_conv2d",
"batch_norm",
"concat",
"tanh",
"pad",
"elementwise_add",
"elementwise_mul",
"dropout",
"prelu",
"conv2d_transpose",
"depthwise_conv2d_transpose",
"leaky_relu",
"fc",
"shuffle_channel",
"swish",
"split",
"instance_norm",
"gelu",
"layer_norm",
"scale",
"stack",
"transpose2",
"transpose",
"flatten2",
"flatten",
"gather",
"gather_nd",
"yolo_box",
"roi_align",
"affine_channel",
"nearest_interp",
"anchor_generator",
"reduce_sum",
"reduce_mean",
};
std::unordered_set<std::string> teller_set{"mul",
"matmul",
"conv2d",
"conv2d_fusion",
"pool2d",
"relu",
"softmax",
"sigmoid",
"hard_swish",
"depthwise_conv2d",
"batch_norm",
"concat",
"tanh",
"pad",
"elementwise_add",
"elementwise_mul",
"dropout",
"prelu",
"conv2d_transpose",
"depthwise_conv2d_transpose",
"leaky_relu",
"fc",
"shuffle_channel",
"swish",
"split",
"instance_norm",
"gelu",
"layer_norm",
"scale",
"stack",
"transpose2",
"transpose",
"flatten2",
"flatten",
"gather",
"gather_nd",
"yolo_box",
"roi_align",
"affine_channel",
"nearest_interp",
"anchor_generator",
"reduce_sum",
"reduce_mean",
"conv3d",
"conv3d_transpose"};
};
bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
......@@ -767,6 +767,65 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
#endif
if (op_type == "conv3d" || op_type == "conv3d_transpose") {
if (desc.HasAttr("padding_algorithm")) {
std::string padding_algorithm =
BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
// TensorRT raises an error when conv3d_transpose uses SAME padding
// without dynamic shape
if (op_type == "conv3d_transpose" && padding_algorithm == "SAME" &&
!with_dynamic_shape) {
return false;
}
}
#if !IS_TRT_VERSION_GE(7000)
// conv3d with dynamic shape is unreliable on TRT 6.0, so reject it
if (with_dynamic_shape) {
return false;
}
#endif
std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
// conv3d and conv3d_transpose support at most 3 padding values;
// asymmetric (6-value) padding is rejected
if (paddings.size() > 3) return false;
if (desc.Input("Input").size() != 1) {
VLOG(3) << "TRT Conv3d expect 1 input, but got "
<< desc.Input("Input").size() << " input.";
return false;
}
if (desc.Input("Filter").size() != 1) {
VLOG(3) << "TRT Conv3d expect 1 filter, but got "
<< desc.Input("Filter").size() << " filter.";
return false;
}
if (op_type == "conv3d_transpose") {
if (!desc.HasAttr("dilations")) {
return false;
} else {
const std::vector<int> dilations =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("dilations"));
if (dilations[0] != 1 || dilations[1] != 1 || dilations[2] != 1) {
VLOG(3) << "In conv3d_transpose, Dilations must be (1, 1, 1) for "
"tensorRT, but given ("
<< dilations[0] << ", " << dilations[1] << ", "
<< dilations[2] << ")";
return false;
}
}
}
if (desc.Output("Output").size() != 1) {
VLOG(3) << "TRT Conv3d expect 1 output, but got "
<< desc.Output("Output").size() << " output.";
return false;
}
}
if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
}
......
......@@ -66,4 +66,6 @@ set_tests_properties(test_trt_tile_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_trt_fc_fuse_quant_dequant_pass PROPERTIES TIMEOUT 100)
set_tests_properties(test_trt_conv_quant_dequant_pass PROPERTIES TIMEOUT 100)
set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100)
set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60)
endif()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig
class TensorRTSubgraphPassConv3dTest(InferencePassTest):
def setUp(self):
self.init_params()
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 6, 32, 32], dtype="float32")
conv_out = fluid.layers.conv3d(
input=data,
num_filters=self.conv_num_filters,
filter_size=self.conv_filter_size,
groups=self.conv_groups,
padding=self.conv_padding,
bias_attr=False,
use_cudnn=self.use_cudnn,
stride=self.stride,
act=None)
self.feeds = {
"data": np.random.random([1, 3, 6, 32, 32]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassConv3dTest.TensorRTParam(
1 << 30, 32, 1, self.precision, self.use_static, False)
self.fetch_list = [conv_out]
def init_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 3
self.conv_padding = [1, 1, 1]
self.use_cudnn = True
self.use_static = False
self.precision = AnalysisConfig.Precision.Float32
self.stride = 1
def set_params(self):
pass
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassConv3dValidPaddingTest(
TensorRTSubgraphPassConv3dTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 3
self.conv_padding = 'VALID'
class TensorRTSubgraphPassConv3dSamePaddingTest(TensorRTSubgraphPassConv3dTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 3
self.conv_padding = 'SAME'
class TensorRTSubgraphPassConv3dPaddingTest(TensorRTSubgraphPassConv3dTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 3
self.conv_padding = [2, 3, 3]
class TensorRTSubgraphPassConv3dStrideTest(TensorRTSubgraphPassConv3dTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 3
self.conv_padding = 'SAME'
self.stride = [1, 2, 2]
class DynamicShapeTensorRTSubgraphPassConv3dTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 6, -1, -1, -1], dtype="float32")
conv_out = fluid.layers.conv3d(
input=data,
num_filters=self.conv_num_filters,
filter_size=self.conv_filter_size,
groups=self.conv_groups,
padding=self.conv_padding,
bias_attr=False,
use_cudnn=self.use_cudnn,
stride=self.stride,
act=None)
self.feeds = {
"data": np.random.random([1, 6, 32, 32, 8]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = DynamicShapeTensorRTSubgraphPassConv3dTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.dynamic_shape_params = DynamicShapeTensorRTSubgraphPassConv3dTest.DynamicShapeParam(
{
"data": [1, 6, 8, 8, 8],
"conv3d_0.tmp_0": [1, 6, 8, 8, 4],
}, {
"data": [32, 6, 32, 32, 8],
"conv3d_0.tmp_0": [32, 6, 32, 32, 8],
}, {
"data": [16, 6, 16, 16, 8],
"conv3d_0.tmp_0": [16, 6, 16, 16, 8],
}, False)
self.fetch_list = [conv_out]
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 6
self.conv_padding = 'SAME'
self.use_cudnn = True
self.stride = [2, 2, 2]
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig
class TensorRTSubgraphPassConv3dTransposeTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 4, 4, 32, 32], dtype="float32")
conv_out = fluid.layers.conv3d_transpose(
input=data,
num_filters=self.conv_num_filters,
filter_size=self.conv_filter_size,
groups=self.conv_groups,
padding=self.conv_padding,
bias_attr=False,
use_cudnn=self.use_cudnn,
stride=1,
act=None)
self.feeds = {
"data": np.random.random([1, 4, 4, 32, 32]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassConv3dTransposeTest.TensorRTParam(
1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [conv_out]
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 1
self.conv_padding = [1, 1, 1]
self.use_cudnn = True
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassConv3dTransposeValidPaddingTest(
TensorRTSubgraphPassConv3dTransposeTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 1
self.conv_padding = 'VALID'
self.use_cudnn = True
class TensorRTSubgraphPassConv3dTransposeMultigroupTest(
TensorRTSubgraphPassConv3dTransposeTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 2
self.conv_padding = 'VALID'
self.use_cudnn = True
class DynamicShapeTensorRTSubgraphPassConv3dTransposeTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 6, -1, -1, -1], dtype="float32")
conv_out = fluid.layers.conv3d_transpose(
input=data,
num_filters=self.conv_num_filters,
filter_size=self.conv_filter_size,
groups=self.conv_groups,
padding=self.conv_padding,
bias_attr=False,
use_cudnn=self.use_cudnn,
stride=self.stride,
act=None)
self.feeds = {
"data": np.random.random([1, 6, 32, 32, 8]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = DynamicShapeTensorRTSubgraphPassConv3dTransposeTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.dynamic_shape_params = DynamicShapeTensorRTSubgraphPassConv3dTransposeTest.DynamicShapeParam(
{
"data": [1, 6, 8, 8, 8],
"conv3d_transpose_0.tmp_0": [1, 6, 8, 8, 1],
}, {
"data": [32, 6, 32, 32, 8],
"conv3d_transpose_0.tmp_0": [32, 6, 64, 64, 16],
}, {
"data": [16, 6, 16, 16, 8],
"conv3d_transpose_0.tmp_0": [16, 6, 16, 16, 8],
}, False)
self.fetch_list = [conv_out]
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 6
self.conv_padding = 'SAME'
self.use_cudnn = True
self.stride = [2, 2, 2]
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
if __name__ == "__main__":
unittest.main()