未验证 提交 292b7254 编写于 作者: C ccrrong 提交者: GitHub

add arg max trt converter support dynamic shape mode (#43473)

* fix arg_max converter
上级 e176cc40
...@@ -40,7 +40,9 @@ class ArgMaxOpConverter : public OpConverter { ...@@ -40,7 +40,9 @@ class ArgMaxOpConverter : public OpConverter {
int axis = op_desc.HasAttr("axis") int axis = op_desc.HasAttr("axis")
? BOOST_GET_CONST(int64_t, op_desc.GetAttr("axis")) ? BOOST_GET_CONST(int64_t, op_desc.GetAttr("axis"))
: -1; : -1;
if (axis > 0) axis -= 1; if (axis > 0 && !engine_->with_dynamic_shape()) {
axis -= 1;
}
if (axis < 0) axis += rank; if (axis < 0) axis += rank;
auto* topk_layer = TRT_ENGINE_ADD_LAYER( auto* topk_layer = TRT_ENGINE_ADD_LAYER(
engine_, TopK, *input, nvinfer1::TopKOperation::kMAX, 1, 1 << axis); engine_, TopK, *input, nvinfer1::TopKOperation::kMAX, 1, 1 << axis);
......
...@@ -738,7 +738,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -738,7 +738,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
} }
if (op_type == "arg_max") { if (op_type == "arg_max") {
if (with_dynamic_shape) return false;
int axis = desc.HasAttr("axis") int axis = desc.HasAttr("axis")
? BOOST_GET_CONST(int64_t, desc.GetAttr("axis")) ? BOOST_GET_CONST(int64_t, desc.GetAttr("axis"))
: -1; : -1;
......
...@@ -45,6 +45,7 @@ class TrtConvertArgMaxTest(TrtLayerAutoScanTest): ...@@ -45,6 +45,7 @@ class TrtConvertArgMaxTest(TrtLayerAutoScanTest):
for batch in [1, 4]: for batch in [1, 4]:
for axis in [-1, 0, 1, 2, 3]: for axis in [-1, 0, 1, 2, 3]:
for keepdims in [True, False]: for keepdims in [True, False]:
self.rank = rank
flatten = False flatten = False
dtype = 2 dtype = 2
ops_config = [{ ops_config = [{
...@@ -76,9 +77,59 @@ class TrtConvertArgMaxTest(TrtLayerAutoScanTest): ...@@ -76,9 +77,59 @@ class TrtConvertArgMaxTest(TrtLayerAutoScanTest):
def sample_predictor_configs( def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float): self, program_config) -> (paddle_infer.Config, List[int], float):
self.trt_param.precision = paddle_infer.PrecisionType.Float32
def generate_dynamic_shape(attrs):
if self.rank == 3:
self.dynamic_shape.min_input_shape = {
"arg_max_input": [1, 8, 16]
}
self.dynamic_shape.max_input_shape = {
"arg_max_input": [4, 8, 16]
}
self.dynamic_shape.opt_input_shape = {
"arg_max_input": [3, 8, 16]
}
else:
self.dynamic_shape.min_input_shape = {
"arg_max_input": [1, 8, 16, 24]
}
self.dynamic_shape.max_input_shape = {
"arg_max_input": [4, 8, 16, 24]
}
self.dynamic_shape.opt_input_shape = {
"arg_max_input": [1, 8, 16, 24]
}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 2
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
self.trt_param.workspace_size = 1024000 self.trt_param.workspace_size = 1024000
yield self.create_inference_config(), [1, 2], 1e-5 # for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-5
def test(self): def test(self):
self.run_test() self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册