Unverified commit 56b1ccb7, authored by baoachun and committed by GitHub

remove input dim check in op_teller and update ut (#37097) (#37773)

Parent fe43beed
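
All of the deleted blocks in the op_teller.cc hunks below follow the same pattern: a rank-2 input used to be rejected in TensorRT static-shape mode for activation, softmax, concat, batch_norm, split, gelu, scale, swish, prelu and clip, and the corresponding skip cases are dropped from the converter unit tests. A minimal, self-contained sketch of that pattern (the helper name and signature are illustrative only, not the real OpTeller interface):

```cpp
// Illustrative only: a stand-alone mock of the rank check this commit removes.
// TellOpToTrt is a hypothetical helper, not part of Paddle's OpTeller API.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

bool TellOpToTrt(const std::string& op_type,
                 const std::vector<int64_t>& x_shape,
                 bool with_dynamic_shape) {
  // Rank-1 inputs are still rejected after this commit.
  if (x_shape.size() == 1) {
    std::cout << op_type << " op does not support input's dim is 1 in tensorrt\n";
    return false;
  }
  // The guard deleted by this commit looked like this:
  // if (x_shape.size() == 2 && !with_dynamic_shape) {
  //   // "the output shape has diff" -> fall back to the native op
  //   return false;
  // }
  return true;
}

int main() {
  // With the guard removed, a rank-2 input in static-shape mode is now
  // eligible for the TensorRT engine.
  std::cout << std::boolalpha
            << TellOpToTrt("gelu", {32, 128}, /*with_dynamic_shape=*/false)
            << std::endl;  // prints true after this commit, false before
  return 0;
}
```
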
@@ -201,12 +201,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
               << " op does not support input's dim is 1 in tensorrt.";
       return false;
     }
-    // TODO(inference): fix
-    if (x_shape.size() == 2 && !with_dynamic_shape) {
-      VLOG(3) << "activation op does not support input's dim is 2 in "
-                 "tensorrt static shape, the output shape has diff.";
-      return false;
-    }
   }
   if (op_type == "pool2d") {
@@ -410,12 +404,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     auto x_var_name = desc.Input("X")[0];
     auto* x_var_desc = block->FindVar(x_var_name);
     const auto x_shape = x_var_desc->GetShape();
-    // TODO(inference): fix
-    if (x_shape.size() == 2 && !with_dynamic_shape) {
-      VLOG(3) << "softmax op does not support input's dim is 2 in tensorrt "
-                 "static shape, the output shape has diff.";
-      return false;
-    }
   }
   if (op_type == "group_norm") {
     if (!with_dynamic_shape) return false;
@@ -441,22 +429,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
         return false;
       }
     }
-    auto* block = desc.Block();
-    if (block == nullptr) {
-      VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
-                 "Developers need to check whether block_desc is passed in "
-                 "the pass.";
-      return false;
-    }
-    auto x_var_name = desc.Input("X")[0];
-    auto* x_var_desc = block->FindVar(x_var_name);
-    const auto x_shape = x_var_desc->GetShape();
-    // TODO(inference): fix
-    if (x_shape.size() == 2 && !with_dynamic_shape) {
-      VLOG(3) << "concat op does not support input's dim is 2 in tensorrt "
-                 "static shape, the output shape has diff.";
-      return false;
-    }
   }
   if (op_type == "transpose2" || op_type == "transpose") {
     if (!desc.HasAttr("axis")) {
@@ -756,12 +728,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     auto x_var_name = desc.Input("X")[0];
     auto* x_var_desc = block->FindVar(x_var_name);
     const auto x_shape = x_var_desc->GetShape();
-    // TODO(inference): fix
-    if (x_shape.size() == 2 && !with_dynamic_shape) {
-      VLOG(3) << "batch_norm op does not support input's dim is 2 in "
-                 "tensorrt static shape, the output shape has diff.";
-      return false;
-    }
   }
   if (op_type == "split") {
@@ -849,13 +815,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       VLOG(3) << "The output_length should be equal to the output size.";
       return false;
     }
-    // TODO(inference): fix
-    if (x_shape.size() == 2 && !with_dynamic_shape) {
-      VLOG(3) << "split op does not support input's dim is 2 in tensorrt "
-                 "static shape. The output shape has diff.";
-      return false;
-    }
   }
   if (op_type == "scale") {
     auto scale_inputs = desc.Inputs();
     if (scale_inputs.find("ScaleTensor") != scale_inputs.end()) {
@@ -873,11 +834,27 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     auto x_var_name = desc.Input("X")[0];
     auto* x_var_desc = block->FindVar(x_var_name);
     const auto x_shape = x_var_desc->GetShape();
-    if (!with_dynamic_shape && x_shape.size() == 1) return false;
+    if (!with_dynamic_shape && x_shape.size() == 1) {
+      VLOG(3) << "Scale op does not support 1-dimensional input in tensorrt";
+      return false;
+    }
   }
   if (op_type == "slice") {
+    if (desc.HasAttr("decrease_axis")) {
+      std::vector<int> decrease_axis =
+          BOOST_GET_CONST(std::vector<int>, desc.GetAttr("decrease_axis"));
+      if (decrease_axis.size() > 0) {
+        VLOG(3) << "Invalid slice decrease_axis. decrease_axis.size() > 0"
+                   "is not supported in TensorRT";
+        return false;
+      }
+    }
     if (!desc.HasAttr("axes") || !desc.HasAttr("starts") ||
-        !desc.HasAttr("ends") || !desc.HasAttr("decrease_axis")) {
+        !desc.HasAttr("ends")) {
+      VLOG(3) << "The necessary attributes of the slice operator axes "
+                 "or starts or ends are missing.";
       return false;
     } else {
       std::vector<int> axes =
@@ -886,14 +863,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
           BOOST_GET_CONST(std::vector<int>, desc.GetAttr("starts"));
       std::vector<int> ends =
           BOOST_GET_CONST(std::vector<int>, desc.GetAttr("ends"));
-      std::vector<int> decrease_axis =
-          BOOST_GET_CONST(std::vector<int>, desc.GetAttr("decrease_axis"));
       if (axes.size() != starts.size() || axes.size() != ends.size()) {
-        return false;
-      }
-      if (decrease_axis.size() > 0) {
-        VLOG(3) << "Invalid slice decrease_axis. decrease_axis.size() > 0"
-                   "is not supported in TensorRT";
+        VLOG(3) << "The shape of attributes of the slice operator axes "
+                   "or starts or ends are not equal.";
         return false;
       }
       if (!with_dynamic_shape) {
@@ -1007,12 +980,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       VLOG(3) << "gelu op does not support input's dim is 1 in tensorrt.";
       return false;
     }
-    // TODO(inference): fix
-    if (x_shape.size() == 2 && !with_dynamic_shape) {
-      VLOG(3) << "gelu op does not support input's dim is 2 in tensorrt "
-                 "static shape, the output shape has diff.";
-      return false;
-    }
   }
   if (op_type == "layer_norm") {
@@ -1132,29 +1099,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     }
   }
-  if (op_type == "scale") {
-    auto* block = desc.Block();
-    if (block == nullptr) {
-      VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
-                 "Developers need to check whether block_desc is passed in "
-                 "the pass.";
-      return false;
-    }
-    auto x_var_name = desc.Input("X")[0];
-    auto* x_var_desc = block->FindVar(x_var_name);
-    const auto x_shape = x_var_desc->GetShape();
-    if (x_shape.size() == 1) {
-      VLOG(3) << "scale op does not support input's dim is 1 in tensorrt.";
-      return false;
-    }
-    // TODO(inference): fix
-    if (x_shape.size() == 2 && !with_dynamic_shape) {
-      VLOG(3) << "scale op does not support input's dim is 2 in tensorrt "
-                 "static shape, the output shape has diff.";
-      return false;
-    }
-  }
   if (op_type == "swish") {
     auto* block = desc.Block();
     if (block == nullptr) {
@@ -1170,12 +1114,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       VLOG(3) << "swish op does not support input's dim is 1 in tensorrt.";
       return false;
     }
-    // TODO(inference): fix
-    if (x_shape.size() == 2 && !with_dynamic_shape) {
-      VLOG(3) << "swish op does not support input's dim is 2 in tensorrt "
-                 "static shape, the output shape has diff.";
-      return false;
-    }
   }
   if (op_type == "prelu") {
@@ -1213,13 +1151,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       return false;
     }
-    if (!with_dynamic_shape) {
-      if (x_shape.size() == 2) {
-        VLOG(3) << "prelu op does not support input's dim is 2 in tensorrt.";
-        return false;
-      }
-    }
 #if IS_TRT_VERSION_LT(7000)
     if (!with_dynamic_shape) {
       // TODO(inference): fix trt6 static plugin error.
@@ -1397,12 +1328,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       VLOG(3) << "clip op does not support input's dim is 1 in tensorrt.";
       return false;
     }
-    // TODO(inference): fix
-    if (x_shape.size() == 2 && !with_dynamic_shape) {
-      VLOG(3) << "clip op does not support input's dim is 2 in tensorrt "
-                 "static shape, the output shape has diff.";
-      return false;
-    }
   }
   if (op_type == "reduce_sum" || op_type == "reduce_mean") {
@@ -1518,15 +1443,17 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     if (!with_dynamic_shape) {
       auto* block = desc.Block();
       if (block == nullptr) {
-        VLOG(3) << "The block is null.";
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
         return false;
       }
       auto x_var_name = desc.Input("X")[0];
       auto* x_var_desc = block->FindVar(x_var_name);
       const auto x_shape = x_var_desc->GetShape();
-      if (x_shape.size() <= 2) {
-        VLOG(3) << "hard_sigmoid op does not support input's dim less than 3 "
-                   "in tensorrt.";
+      if (x_shape.size() == 1) {
+        VLOG(3) << "Hard sigmoid does not support 1-dimensional input in "
+                   "tensorrt";
         return false;
       }
     }
......
@@ -126,18 +126,7 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
         yield self.create_inference_config(), generate_trt_nodes_num(attrs,
                                                                      True), 1e-5
-    def add_skip_trt_case(self):
-        def teller1(program_config, predictor_config):
-            if self.dims == 2:
-                return True
-            return False
-
-        self.add_skip_case(
-            teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
-            "When input dims is 2, pulgin will product a 4 dims output.")
-
     def test(self):
-        self.add_skip_trt_case()
         self.run_test()
......
@@ -212,18 +212,6 @@ class TrtConvertBatchNormTest(TrtLayerAutoScanTest):
         self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT,
                            "INPUT MomentumTensor NOT SUPPORT")
-        def teller2(program_config, predictor_config):
-            if len(
-                    program_config.inputs['batch_norm_input'].shape
-            ) == 2 and not predictor_config.tensorrt_dynamic_shape_enabled():
-                return True
-            return False
-
-        self.add_skip_case(
-            teller2, SkipReasons.TRT_NOT_IMPLEMENTED,
-            "The output shape has diff, but we can add shuffle layer to resolve it."
-        )
-
     def test(self):
         self.add_skip_trt_case()
         self.run_test()
......
@@ -146,21 +146,7 @@ class TrtConvertClipTest(TrtLayerAutoScanTest):
         yield self.create_inference_config(), generate_trt_nodes_num(attrs,
                                                                      True), 1e-5
-    def add_skip_trt_case(self):
-        def teller1(program_config, predictor_config):
-            if len(
-                    program_config.inputs['input_data'].shape
-            ) == 2 and not predictor_config.tensorrt_dynamic_shape_enabled():
-                return True
-            return False
-
-        self.add_skip_case(
-            teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
-            "The output shape has diff, but we can add shuffle layer to resolve it."
-        )
-
     def test(self):
-        self.add_skip_trt_case()
         self.run_test()
......
@@ -318,18 +318,6 @@ class TrtConvertConcatTest(TrtLayerAutoScanTest):
         self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT,
                            "INPUT AxisTensor NOT SUPPORT")
-        def teller2(program_config, predictor_config):
-            if len(
-                    program_config.inputs['concat_input1'].shape
-            ) == 2 and not predictor_config.tensorrt_dynamic_shape_enabled():
-                return True
-            return False
-
-        self.add_skip_case(
-            teller2, SkipReasons.TRT_NOT_IMPLEMENTED,
-            "The output shape has diff, but we can add shuffle layer to resolve it."
-        )
-
     def test(self):
         self.add_skip_trt_case()
         self.run_test()
......
@@ -126,18 +126,7 @@ class TrtConvertGeluTest(TrtLayerAutoScanTest):
         yield self.create_inference_config(), generate_trt_nodes_num(attrs,
                                                                      True), 1e-5
-    def add_skip_trt_case(self):
-        def teller1(program_config, predictor_config):
-            if self.dims == 2:
-                return True
-            return False
-
-        self.add_skip_case(
-            teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
-            "When input dims is 2, pulgin will product a 4 dims output.")
-
     def test(self):
-        self.add_skip_trt_case()
         self.run_test()
......
@@ -106,20 +106,7 @@ class TrtConvertHardSigmoidTest_dim_2(TrtLayerAutoScanTest):
         self.trt_param.precision = paddle_infer.PrecisionType.Half
         yield self.create_inference_config(), (1, 2), 1e-5
-    def add_skip_trt_case(self):
-        def teller(program_config, predictor_config):
-            if len(self.dynamic_shape.
-                   min_input_shape) == 0 and self.input_dim == 2:
-                return True
-            return False
-
-        self.add_skip_case(
-            teller, SkipReasons.TRT_NOT_SUPPORT,
-            "Need to repair the case: the output of trt and GPU has diff when inputs' dims is 2 in static shape mode."
-        )
-
     def test(self):
-        self.add_skip_trt_case()
         self.run_test()
......
@@ -176,17 +176,6 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
         self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT,
                            "Trt does not support 1-dimensional input.")
-        def teller2(program_config, predictor_config):
-            if (len(self.dynamic_shape.min_input_shape) == 0):
-                if self.dim1 != 0 and self.dim2 == 0 and self.dim3 == 0:
-                    return True
-            return False
-
-        self.add_skip_case(
-            teller2, SkipReasons.TRT_NOT_SUPPORT,
-            "Need to repair the case: the output of GPU and tensorrt has diff when the input dimension is 2 in static shape mode."
-        )
-
         ver = paddle_infer.get_trt_compile_version()
         if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000:
......
@@ -145,7 +145,7 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
     def add_skip_trt_case(self):
         def teller1(program_config, predictor_config):
-            if len(program_config.weights) == 1:
+            if self.num_input == 0:
                 return True
             return False
@@ -153,7 +153,7 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
                            "INPUT ScaleTensor and Shape NOT SUPPORT")

        def teller2(program_config, predictor_config):
-            if self.dims == 1 and self.dynamic_shape.min_input_shape == 0:
+            if self.dims == 1 and len(self.dynamic_shape.min_input_shape) == 0:
                 return True
             return False
......
@@ -135,21 +135,7 @@ class TrtConvertSoftmaxTest(TrtLayerAutoScanTest):
         yield self.create_inference_config(), generate_trt_nodes_num(attrs,
                                                                      True), 1e-5
-    def add_skip_trt_case(self):
-        def teller1(program_config, predictor_config):
-            if len(
-                    program_config.inputs['softmax_input'].shape
-            ) == 2 and not predictor_config.tensorrt_dynamic_shape_enabled():
-                return True
-            return False
-
-        self.add_skip_case(
-            teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
-            "The output shape has diff, but we can add shuffle layer to resolve it."
-        )
-
     def test(self):
-        self.add_skip_trt_case()
         self.run_test()
......
@@ -227,18 +227,6 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
            teller1, SkipReasons.TRT_NOT_SUPPORT,
            "INPUT AxisTensor AND SectionsTensorList NOT SUPPORT.")
-        def teller2(program_config, predictor_config):
-            if len(
-                    program_config.inputs['split_input'].shape
-            ) == 2 and not predictor_config.tensorrt_dynamic_shape_enabled():
-                return True
-            return False
-
-        self.add_skip_case(
-            teller2, SkipReasons.TRT_NOT_IMPLEMENTED,
-            "The output shape has diff, but we can add shuffle layer to resolve it."
-        )
-
     def test(self):
         self.add_skip_trt_case()
         self.run_test()
......
@@ -126,18 +126,7 @@ class TrtConvertSwishTest(TrtLayerAutoScanTest):
         yield self.create_inference_config(), generate_trt_nodes_num(attrs,
                                                                      True), 1e-5
-    def add_skip_trt_case(self):
-        def teller1(program_config, predictor_config):
-            if self.dims == 2:
-                return True
-            return False
-
-        self.add_skip_case(
-            teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
-            "When input dims is 2, pulgin will product a 4 dims output.")
-
     def test(self):
-        self.add_skip_trt_case()
         self.run_test()
......