未验证 提交 aebff6d7 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle-Inference] Support trt 0dims of expand_as_v2 and mish. (#53627)

* support_expand_mish
上级 08b6f5d6
...@@ -1890,8 +1890,10 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -1890,8 +1890,10 @@ struct SimpleOpTypeSetTeller : public Teller {
auto x_var_name = desc.Input("X")[0]; auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name); auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape(); const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) { if ((!with_dynamic_shape && x_shape.size() == 1) || x_shape.size() == 0) {
VLOG(3) << "mish op does not support input's dim is 1 in tensorrt."; VLOG(3) << op_type
<< "mish op does not support input's dim is 1 in tensorrt "
"static shape mode or 0.";
return false; return false;
} }
} }
...@@ -2644,6 +2646,15 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2644,6 +2646,15 @@ struct SimpleOpTypeSetTeller : public Teller {
"the pass."; "the pass.";
return false; return false;
} }
#if IS_TRT_VERSION_LT(8000)
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 0) {
return false; // not supported 0 dim.
}
#endif
} }
if (op_type == "grid_sampler") { if (op_type == "grid_sampler") {
......
...@@ -49,8 +49,11 @@ class TrtConvertExpandASV2Test(TrtLayerAutoScanTest): ...@@ -49,8 +49,11 @@ class TrtConvertExpandASV2Test(TrtLayerAutoScanTest):
elif self.dims == 1: elif self.dims == 1:
self.input_shape = [32] self.input_shape = [32]
return np.random.random([32]).astype(np.float32) return np.random.random([32]).astype(np.float32)
elif self.dims == 0:
self.input_shape = []
return np.random.random([]).astype(np.float32)
for dims in [1, 2, 3, 4]: for dims in [0, 1, 2, 3, 4]:
for shape in [ for shape in [
[10, 8, 32, 32], [10, 8, 32, 32],
[2, 8, 32, 32], [2, 8, 32, 32],
...@@ -125,6 +128,10 @@ class TrtConvertExpandASV2Test(TrtLayerAutoScanTest): ...@@ -125,6 +128,10 @@ class TrtConvertExpandASV2Test(TrtLayerAutoScanTest):
self.dynamic_shape.min_input_shape = {"expand_v2_input": [32]} self.dynamic_shape.min_input_shape = {"expand_v2_input": [32]}
self.dynamic_shape.max_input_shape = {"expand_v2_input": [64]} self.dynamic_shape.max_input_shape = {"expand_v2_input": [64]}
self.dynamic_shape.opt_input_shape = {"expand_v2_input": [32]} self.dynamic_shape.opt_input_shape = {"expand_v2_input": [32]}
elif self.dims == 0:
self.dynamic_shape.min_input_shape = {"expand_v2_input": []}
self.dynamic_shape.max_input_shape = {"expand_v2_input": []}
self.dynamic_shape.opt_input_shape = {"expand_v2_input": []}
def clear_dynamic_shape(): def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {} self.dynamic_shape.min_input_shape = {}
...@@ -132,7 +139,9 @@ class TrtConvertExpandASV2Test(TrtLayerAutoScanTest): ...@@ -132,7 +139,9 @@ class TrtConvertExpandASV2Test(TrtLayerAutoScanTest):
self.dynamic_shape.opt_input_shape = {} self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape): def generate_trt_nodes_num(attrs, dynamic_shape):
if dynamic_shape: ver = paddle_infer.get_trt_compile_version()
ver_num = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10
if dynamic_shape and (ver_num > 8000 or self.dims > 0):
return 1, 2 return 1, 2
else: else:
return 0, 3 return 0, 3
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册