未验证 提交 dfa242e4 编写于 作者: J JingZhuangzhuang 提交者: GitHub

fix trt convert conv2d skip (#38999)

* fix trt convert conv2d skip

* fix trt convert conv2d skip
上级 27f8460a
...@@ -76,12 +76,17 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, ...@@ -76,12 +76,17 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("dilations")); BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("dilations"));
const std::vector<int> strides = const std::vector<int> strides =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides")); BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
const std::vector<int> paddings = std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings")); BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));
std::string padding_algorithm = "EXPLICIT"; std::string padding_algorithm = "EXPLICIT";
if (op_desc.HasAttr("padding_algorithm")) if (op_desc.HasAttr("padding_algorithm"))
padding_algorithm = padding_algorithm =
BOOST_GET_CONST(std::string, op_desc.GetAttr("padding_algorithm")); BOOST_GET_CONST(std::string, op_desc.GetAttr("padding_algorithm"));
if (padding_algorithm == "VALID") {
for (size_t i = 0; i < paddings.size(); i++) {
paddings[i] = 0;
}
}
nvinfer1::DimsHW nv_ksize(filter_h, filter_w); nvinfer1::DimsHW nv_ksize(filter_h, filter_w);
nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]); nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]);
...@@ -139,6 +144,8 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, ...@@ -139,6 +144,8 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
layer->setNbGroups(groups); layer->setNbGroups(groups);
if (padding_algorithm == "SAME") { if (padding_algorithm == "SAME") {
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
nv_dilations.d[0] = 1;
nv_dilations.d[1] = 1;
} }
// set dilations // set dilations
fset_dilation(layer, nv_dilations); fset_dilation(layer, nv_dilations);
......
...@@ -271,36 +271,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -271,36 +271,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false; return false;
} }
if (desc.HasAttr("padding_algorithm")) {
auto padding_algorithm =
BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
if (padding_algorithm == "VALID") {
return false;
}
if (padding_algorithm == "SAME") {
if (desc.HasAttr("dilations")) {
const std::vector<int> dilations =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("dilations"));
if (dilations[0] != 1 || dilations[1] != 1) {
VLOG(3) << "In Same mode, Dilations must be (1, 1) for "
"tensorRT, but given ("
<< dilations[0] << ", " << dilations[1] << ")";
return false;
}
}
}
}
if (use_no_calib_int8) {
if (desc.HasAttr("padding_algorithm")) {
auto padding_algorithm =
BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
if (padding_algorithm == "SAME") {
return false;
}
}
}
if (desc.HasAttr("enable_int8")) { if (desc.HasAttr("enable_int8")) {
if (op_type == "conv2d" || op_type == "conv2d_fusion") { if (op_type == "conv2d" || op_type == "conv2d_fusion") {
if (!desc.HasAttr("Input_scale")) { if (!desc.HasAttr("Input_scale")) {
......
...@@ -34,6 +34,12 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest): ...@@ -34,6 +34,12 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
1] * attrs[0]['groups']: 1] * attrs[0]['groups']:
return False return False
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000:
if attrs[0]['padding_algorithm'] == 'SAME' and (
attrs[0]['strides'][0] > 1 or attrs[0]['strides'][1] > 1):
return False
return True return True
def sample_program_configs(self): def sample_program_configs(self):
...@@ -68,39 +74,27 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest): ...@@ -68,39 +74,27 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
"data_format": data_format "data_format": data_format
}, {}] }, {}]
if padding_algorithm == 'EXPLICIT': ops_config = [{
ops_config = [{ "op_type": "conv2d",
"op_type": "conv2d", "op_inputs": {
"op_inputs": { "Input": ["input_data"],
"Input": ["input_data"], "Filter": ["conv2d_weight"]
"Filter": ["conv2d_weight"] },
}, "op_outputs": {
"op_outputs": { "Output": ["conv_output_data"]
"Output": ["conv_output_data"] },
}, "op_attrs": dics[0]
"op_attrs": dics[0] }, {
}, { "op_type": "relu",
"op_type": "relu", "op_inputs": {
"op_inputs": { "X": ["conv_output_data"]
"X": ["conv_output_data"] },
}, "op_outputs": {
"op_outputs": { "Out": ["output_data"]
"Out": ["output_data"] },
}, "op_attrs": dics[1]
"op_attrs": dics[1] }]
}]
else:
ops_config = [{
"op_type": "conv2d",
"op_inputs": {
"Input": ["input_data"],
"Filter": ["conv2d_weight"]
},
"op_outputs": {
"Output": ["output_data"]
},
"op_attrs": dics[0]
}]
ops = self.generate_op_config(ops_config) ops = self.generate_op_config(ops_config)
program_config = ProgramConfig( program_config = ProgramConfig(
...@@ -188,7 +182,6 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest): ...@@ -188,7 +182,6 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
attrs, False), (1e-5, 1e-5) attrs, False), (1e-5, 1e-5)
# for dynamic_shape # for dynamic_shape
generate_dynamic_shape(attrs) generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32 self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(attrs, yield self.create_inference_config(), generate_trt_nodes_num(attrs,
...@@ -200,25 +193,10 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest): ...@@ -200,25 +193,10 @@ class TrtConvertConv2dTest(TrtLayerAutoScanTest):
yield self.create_inference_config(), generate_trt_nodes_num( yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), (1e-5, 1e-5) attrs, True), (1e-5, 1e-5)
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs[
'padding_algorithm'] == "SAME" or program_config.ops[
0].attrs['padding_algorithm'] == "VALID":
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"When padding_algorithm is 'SAME' or 'VALID', Trt dose not support. In this case, trt build error is caused by scale op."
)
def test(self): def test(self):
self.add_skip_trt_case()
self.run_test() self.run_test()
def test_quant(self): def test_quant(self):
self.add_skip_trt_case()
self.run_test(quant=True) self.run_test(quant=True)
......
...@@ -37,6 +37,12 @@ class TrtConvertConv2dFusionTest(TrtLayerAutoScanTest): ...@@ -37,6 +37,12 @@ class TrtConvertConv2dFusionTest(TrtLayerAutoScanTest):
if attrs[0]['groups'] <= 1: if attrs[0]['groups'] <= 1:
return False return False
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000:
if attrs[0]['padding_algorithm'] == 'SAME' and (
attrs[0]['strides'][0] > 1 or attrs[0]['strides'][1] > 1):
return False
return True return True
def sample_program_configs(self): def sample_program_configs(self):
...@@ -184,25 +190,10 @@ class TrtConvertConv2dFusionTest(TrtLayerAutoScanTest): ...@@ -184,25 +190,10 @@ class TrtConvertConv2dFusionTest(TrtLayerAutoScanTest):
yield self.create_inference_config(), generate_trt_nodes_num( yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), (1e-5, 1e-5) attrs, True), (1e-5, 1e-5)
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs[
'padding_algorithm'] == "SAME" or program_config.ops[
0].attrs['padding_algorithm'] == "VALID":
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"When padding_algorithm is 'SAME' or 'VALID', Trt dose not support. In this case, trt build error is caused by scale op."
)
def test(self): def test(self):
self.add_skip_trt_case()
self.run_test() self.run_test()
def test_quant(self): def test_quant(self):
self.add_skip_trt_case()
self.run_test(quant=True) self.run_test(quant=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册