diff --git a/paddle/fluid/inference/tensorrt/convert/arg_max_op.cc b/paddle/fluid/inference/tensorrt/convert/arg_max_op.cc index 14975e481644da031195ad0b70941495e78f19e5..df701324a5077c331aceebc0a5d490f7ac799909 100644 --- a/paddle/fluid/inference/tensorrt/convert/arg_max_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/arg_max_op.cc @@ -40,7 +40,9 @@ class ArgMaxOpConverter : public OpConverter { int axis = op_desc.HasAttr("axis") ? BOOST_GET_CONST(int64_t, op_desc.GetAttr("axis")) : -1; - if (axis > 0) axis -= 1; + if (axis > 0 && !engine_->with_dynamic_shape()) { + axis -= 1; + } if (axis < 0) axis += rank; auto* topk_layer = TRT_ENGINE_ADD_LAYER( engine_, TopK, *input, nvinfer1::TopKOperation::kMAX, 1, 1 << axis); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index ece3451184d0511825d49ebdbd6b74cd5b8763b7..f9086d7a822c336255d6e27a8c4f4ddcf88c7100 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -738,7 +738,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } if (op_type == "arg_max") { - if (with_dynamic_shape) return false; int axis = desc.HasAttr("axis") ? BOOST_GET_CONST(int64_t, desc.GetAttr("axis")) : -1; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_arg_max.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_arg_max.py index 82ac600fd1e73f133bbd0f578c3fd1e55bdcb1f2..8d01029c78a7d96ddd4c5d1b77f44577277d3998 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_arg_max.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_arg_max.py @@ -45,6 +45,7 @@ class TrtConvertArgMaxTest(TrtLayerAutoScanTest): for batch in [1, 4]: for axis in [-1, 0, 1, 2, 3]: for keepdims in [True, False]: + self.rank = rank flatten = False dtype = 2 ops_config = [{ @@ -76,9 +77,59 @@ class TrtConvertArgMaxTest(TrtLayerAutoScanTest): def sample_predictor_configs( self, program_config) -> (paddle_infer.Config, List[int], float): - self.trt_param.precision = paddle_infer.PrecisionType.Float32 + + def generate_dynamic_shape(attrs): + if self.rank == 3: + self.dynamic_shape.min_input_shape = { + "arg_max_input": [1, 8, 16] + } + self.dynamic_shape.max_input_shape = { + "arg_max_input": [4, 8, 16] + } + self.dynamic_shape.opt_input_shape = { + "arg_max_input": [3, 8, 16] + } + else: + self.dynamic_shape.min_input_shape = { + "arg_max_input": [1, 8, 16, 24] + } + self.dynamic_shape.max_input_shape = { + "arg_max_input": [4, 8, 16, 24] + } + self.dynamic_shape.opt_input_shape = { + "arg_max_input": [1, 8, 16, 24] + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + self.trt_param.workspace_size = 1024000 - yield self.create_inference_config(), [1, 2], 1e-5 + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 def test(self): self.run_test()