Unverified commit 82c73884, authored by Yuanle Liu, committed by GitHub

[inference Zero-Dim]prelu trt converter support zero dim tensor (#53634)

* prelu op trt converter support zero dim
Parent 5417382d
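For orientation, the snippet below is an illustrative NumPy sketch (the helper `prelu` and the variables `x` and `alpha` are hypothetical, not Paddle or TensorRT API) showing what a zero-dim (rank-0) prelu computation looks like numerically, matching the `dims == 0` case exercised by the updated test:

    import numpy as np

    # Illustrative sketch only: prelu applied to a 0-D (rank-0) tensor,
    # the shape this commit teaches the TensorRT converter and op teller to accept.
    def prelu(x, alpha):
        return np.where(x > 0, x, alpha * x)

    x = np.array(-1.5, dtype=np.float32)      # 0-D array, shape ()
    alpha = np.array(0.25, dtype=np.float32)  # 0-D alpha, like the test's dims == 0 case
    y = prelu(x, alpha)
    print(x.shape, y.shape, y)                # () () -0.375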
@@ -356,6 +356,8 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const {
     }
   };
   InsertTransposeOp();
+
+  AddStatis(transposed_ops.size());
 }
 }  // namespace ir
......
@@ -87,7 +87,6 @@ class PReluOpConverter : public OpConverter {
       if (hw_tensor != nullptr) {
         shape_tensor = Concat(
             std::vector<nvinfer1::ITensor*>{n_tensor, c_tensor, hw_tensor});
       } else {
         shape_tensor =
             Concat(std::vector<nvinfer1::ITensor*>{n_tensor, c_tensor});
......
@@ -1837,28 +1837,28 @@ struct SimpleOpTypeSetTeller : public Teller {
                    "the pass.";
         return false;
       }
-      auto* var_desc = block->FindVar(desc.Input("Alpha")[0]);
-      if (!var_desc) {
+      auto* alpha_var = block->FindVar(desc.Input("Alpha")[0]);
+      if (!alpha_var) {
         VLOG(3) << "Variable Alpha of prelu TRT converter not found.";
         return false;
       }
+      auto alpha_shape = alpha_var->GetShape();
+      if (!with_dynamic_shape && alpha_shape.size() == 0) {
+        VLOG(3) << op_type
+                << " op does not support alpha's dim is 0 in tensorrt "
+                   "static shape mode.";
+        return false;
+      }
       auto x_var_name = desc.Input("X")[0];
-      auto* x_var_desc = block->FindVar(x_var_name);
-      const auto x_shape = x_var_desc->GetShape();
-      if (!with_dynamic_shape && x_shape.size() == 1) {
-        VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt "
-                   "with static shape.";
+      auto* x_var = block->FindVar(x_var_name);
+      const auto x_shape = x_var->GetShape();
+      if (!with_dynamic_shape && (x_shape.size() == 1 || x_shape.size() == 0)) {
+        VLOG(3) << op_type
+                << " op does not support input's dim is 1 or 0 in tensorrt "
+                   "with static shape.";
         return false;
       }
-#if IS_TRT_VERSION_LT(7000)
-      if (!with_dynamic_shape) {
-        // TODO(inference): fix trt6 static plugin error.
-        VLOG(3) << "prelu static plugin in trt6 has bug.";
-        return false;
-      }
-#endif
     }

     if (op_type == "mish") {
......
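Taken together, the teller change above means prelu is offloaded to TensorRT for 0-D or 1-D inputs (or a 0-D alpha) only when dynamic shape is enabled; in static-shape mode it falls back to native Paddle. A hedged restatement of that rule as standalone Python (a hypothetical helper, not the actual Paddle source):

    # Hypothetical restatement of the updated teller rule (not Paddle source code).
    def prelu_supported_by_trt(x_rank: int, alpha_rank: int, with_dynamic_shape: bool) -> bool:
        if not with_dynamic_shape and alpha_rank == 0:
            return False  # 0-D alpha is rejected in static-shape mode
        if not with_dynamic_shape and x_rank in (0, 1):
            return False  # 0-D / 1-D input is rejected in static-shape mode
        return True

    assert prelu_supported_by_trt(x_rank=0, alpha_rank=0, with_dynamic_shape=True)
    assert not prelu_supported_by_trt(x_rank=0, alpha_rank=0, with_dynamic_shape=False)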
@@ -2340,7 +2340,7 @@ void PReluInferMeta(const MetaTensor& x,
           1,
           phi::errors::InvalidArgument(
               "For mode 'element', rank of input X must be "
-              "equal or larger than 2. But recevied X's "
+              "equal or larger than 1. But recevied X's "
               "rank: %d",
               x_rank));
       PADDLE_ENFORCE_EQ(
......
@@ -18,7 +18,7 @@ from typing import Any, Dict, List
 import numpy as np
 from program_config import ProgramConfig, TensorConfig
-from trt_layer_auto_scan_test import SkipReasons, TrtLayerAutoScanTest
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest

 import paddle.inference as paddle_infer
@@ -28,170 +28,165 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
         return True

     def sample_program_configs(self):
-        def generate_input(batch, dim1, dim2, dim3):
-            shape = [batch]
-            if dim1 != 0:
-                shape.append(dim1)
-            if dim2 != 0:
-                shape.append(dim2)
-            if dim3 != 0:
-                shape.append(dim3)
-            return np.random.random(shape).astype(np.float32)
-
-        def generate_alpha(attrs: List[Dict[str, Any]], dim1, dim2, dim3):
+        def generate_input(attrs: List[Dict[str, Any]], batch):
+            if self.dims == 0:
+                return np.random.random([]).astype(np.float32)
+            elif self.dims == 1:
+                return np.random.random([16]).astype(np.float32)
+            elif self.dims == 2:
+                return np.random.random([1, 3]).astype(np.float32)
+            elif self.dims == 3:
+                if attrs[0]["data_format"] == "NCHW":
+                    return np.random.random([batch, 3, 16]).astype(np.float32)
+                elif attrs[0]["data_format"] == "NHWC":
+                    return np.random.random([batch, 16, 3]).astype(np.float32)
+                else:
+                    raise AssertionError()
+            else:
+                if attrs[0]["data_format"] == "NCHW":
+                    return np.random.random([batch, 3, 16, 32]).astype(
+                        np.float32
+                    )
+                else:
+                    return np.random.random([batch, 16, 32, 3]).astype(
+                        np.float32
+                    )
+
+        def generate_alpha(attrs: List[Dict[str, Any]]):
+            if self.dims == 0:
+                return np.random.random([]).astype(np.float32)
             if attrs[0]["mode"] == "all":
-                return np.random.random(size=(1)).astype(np.float32)
-            elif (
-                attrs[0]["mode"] == "channel"
-                and attrs[0]["data_format"] == "NCHW"
-            ):
-                shape = [1]
-                if dim1 != 0:
-                    shape.append(dim1)
-                if dim2 != 0:
-                    shape.append(dim2)
-                if dim3 != 0:
-                    shape.append(dim3)
-                return np.random.random(size=shape[1]).astype(np.float32)
-            elif (
-                attrs[0]["mode"] == "channel"
-                and attrs[0]["data_format"] == "NHWC"
-            ):
-                shape = [1]
-                if dim1 != 0:
-                    shape.append(dim1)
-                if dim2 != 0:
-                    shape.append(dim2)
-                if dim3 != 0:
-                    shape.append(dim3)
-                return np.random.random(size=shape[-1]).astype(np.float32)
+                return np.random.random([1]).astype(np.float32)
+            elif attrs[0]["mode"] == "channel":
+                return np.random.random([3]).astype(np.float32)
             elif attrs[0]["mode"] == "element":
-                shape = [1]
-                if dim1 != 0:
-                    shape.append(dim1)
-                if dim2 != 0:
-                    shape.append(dim2)
-                if dim3 != 0:
-                    shape.append(dim3)
-                return np.random.random(size=shape).astype(np.float32)
+                if self.dims == 1:
+                    return np.random.random([16]).astype(np.float32)
+                elif self.dims == 2:
+                    return np.random.random([1, 3]).astype(np.float32)
+                elif self.dims == 3:
+                    if attrs[0]["data_format"] == "NCHW":
+                        return np.random.random([1, 3, 16]).astype(np.float32)
+                    elif attrs[0]["data_format"] == "NHWC":
+                        return np.random.random([1, 16, 3]).astype(np.float32)
+                    else:
+                        raise AssertionError()
+                else:
+                    if attrs[0]["data_format"] == "NCHW":
+                        return np.random.random([1, 3, 16, 32]).astype(
+                            np.float32
+                        )
+                    elif attrs[0]["data_format"] == "NHWC":
+                        return np.random.random([1, 16, 32, 3]).astype(
+                            np.float32
+                        )
+                    else:
+                        raise AssertionError()

         for batch in [1, 4]:
-            for dim1 in [0, 3]:
-                for dim2 in [0, 16]:
-                    for dim3 in [0, 32]:
-                        self.dim1 = dim1
-                        self.dim2 = dim2
-                        self.dim3 = dim3
-
-                        if dim1 == 0 and dim2 != 0:
+            for dims in [0, 1, 2, 3, 4]:
+                for mode in ["all", "element", "channel"]:
+                    for data_format in ["NCHW", "NHWC"]:
+                        if (mode == "element" or mode == "all") and dims == 0:
                             continue
-                        if dim1 == 0 and dim2 == 0 and dim3 != 0:
+                        if mode == "channel" and dims != 4:
                             continue

-                        for mode in ["all", "channel", "element"]:
-                            for data_format in ['NCHW', 'NHWC']:
-                                if (
-                                    mode == "channel"
-                                    and dim1 == 0
-                                    and data_format == "NCHW"
-                                ):
-                                    continue
-                                if (
-                                    mode == "channel"
-                                    and dim3 == 0
-                                    and data_format == "NHWC"
-                                ):
-                                    continue
-                                dics = [
-                                    {"mode": mode, "data_format": data_format}
-                                ]
-                                ops_config = [
-                                    {
-                                        "op_type": "prelu",
-                                        "op_inputs": {
-                                            "X": ["input_data"],
-                                            "Alpha": ["alpha_weight"],
-                                        },
-                                        "op_outputs": {"Out": ["output_data"]},
-                                        "op_attrs": dics[0],
-                                    }
-                                ]
-                                ops = self.generate_op_config(ops_config)
-
-                                program_config = ProgramConfig(
-                                    ops=ops,
-                                    weights={
-                                        "alpha_weight": TensorConfig(
-                                            data_gen=partial(
-                                                generate_alpha,
-                                                dics,
-                                                dim1,
-                                                dim2,
-                                                dim3,
-                                            )
-                                        )
-                                    },
-                                    inputs={
-                                        "input_data": TensorConfig(
-                                            data_gen=partial(
-                                                generate_input,
-                                                batch,
-                                                dim1,
-                                                dim2,
-                                                dim3,
-                                            )
-                                        ),
-                                    },
-                                    outputs=["output_data"],
-                                )
-
-                                yield program_config
+                        self.dims = dims
+                        dics = [{"mode": mode, "data_format": data_format}]
+                        ops_config = [
+                            {
+                                "op_type": "prelu",
+                                "op_inputs": {
+                                    "X": ["input_data"],
+                                    "Alpha": ["alpha_weight"],
+                                },
+                                "op_outputs": {"Out": ["output_data"]},
+                                "op_attrs": dics[0],
+                            }
+                        ]
+                        ops = self.generate_op_config(ops_config)
+
+                        program_config = ProgramConfig(
+                            ops=ops,
+                            weights={
+                                "alpha_weight": TensorConfig(
+                                    data_gen=partial(generate_alpha, dics)
+                                )
+                            },
+                            inputs={
+                                "input_data": TensorConfig(
+                                    data_gen=partial(
+                                        generate_input, dics, batch
+                                    )
+                                ),
+                            },
+                            outputs=["output_data"],
+                        )
+
+                        yield program_config

     def sample_predictor_configs(
         self, program_config
     ) -> (paddle_infer.Config, List[int], float):
         def generate_dynamic_shape(attrs):
-            if self.dim1 == 0:
-                self.dynamic_shape.min_input_shape = {
-                    "input_data": [1],
-                }
-                self.dynamic_shape.max_input_shape = {
-                    "input_data": [4],
-                }
-                self.dynamic_shape.opt_input_shape = {
-                    "input_data": [2],
-                }
-            else:
-                if self.dim2 == 0 and self.dim3 == 0:
-                    self.dynamic_shape.min_input_shape = {
-                        "input_data": [1, 1],
-                    }
-                    self.dynamic_shape.max_input_shape = {
-                        "input_data": [4, 32],
-                    }
-                    self.dynamic_shape.opt_input_shape = {
-                        "input_data": [2, 3],
-                    }
-                elif self.dim2 != 0 and self.dim3 != 0:
-                    self.dynamic_shape.min_input_shape = {
-                        "input_data": [1, 1, 1, 1],
-                    }
-                    self.dynamic_shape.max_input_shape = {
-                        "input_data": [4, 3, 16, 32],
-                    }
-                    self.dynamic_shape.opt_input_shape = {
-                        "input_data": [2, 3, 16, 32],
-                    }
-                elif self.dim3 == 0:
-                    self.dynamic_shape.min_input_shape = {
-                        "input_data": [1, 1, 1],
-                    }
-                    self.dynamic_shape.max_input_shape = {
-                        "input_data": [4, 3, 32],
-                    }
-                    self.dynamic_shape.opt_input_shape = {
-                        "input_data": [2, 3, 16],
-                    }
+            if self.dims == 0:
+                self.dynamic_shape.min_input_shape = {"input_data": []}
+                self.dynamic_shape.max_input_shape = {"input_data": []}
+                self.dynamic_shape.opt_input_shape = {"input_data": []}
+            elif self.dims == 1:
+                self.dynamic_shape.min_input_shape = {"input_data": [16]}
+                self.dynamic_shape.max_input_shape = {"input_data": [16]}
+                self.dynamic_shape.opt_input_shape = {"input_data": [16]}
+            elif self.dims == 2:
+                self.dynamic_shape.min_input_shape = {"input_data": [1, 3]}
+                self.dynamic_shape.max_input_shape = {"input_data": [1, 3]}
+                self.dynamic_shape.opt_input_shape = {"input_data": [1, 3]}
+            elif self.dims == 3:
+                if attrs[0]["data_format"] == "NCHW":
+                    self.dynamic_shape.min_input_shape = {
+                        "input_data": [1, 3, 16]
+                    }
+                    self.dynamic_shape.max_input_shape = {
+                        "input_data": [4, 3, 16]
+                    }
+                    self.dynamic_shape.opt_input_shape = {
+                        "input_data": [1, 3, 16]
+                    }
+                elif attrs[0]["data_format"] == "NHWC":
+                    self.dynamic_shape.min_input_shape = {
+                        "input_data": [1, 16, 3]
+                    }
+                    self.dynamic_shape.max_input_shape = {
+                        "input_data": [4, 16, 3]
+                    }
+                    self.dynamic_shape.opt_input_shape = {
+                        "input_data": [1, 16, 3]
+                    }
+                else:
+                    raise AssertionError()
+            else:
+                if attrs[0]["data_format"] == "NCHW":
+                    self.dynamic_shape.min_input_shape = {
+                        "input_data": [1, 3, 16, 32]
+                    }
+                    self.dynamic_shape.max_input_shape = {
+                        "input_data": [4, 3, 16, 32]
+                    }
+                    self.dynamic_shape.opt_input_shape = {
+                        "input_data": [1, 3, 16, 32]
+                    }
+                elif attrs[0]["data_format"] == "NHWC":
+                    self.dynamic_shape.min_input_shape = {
+                        "input_data": [1, 16, 32, 3]
+                    }
+                    self.dynamic_shape.max_input_shape = {
+                        "input_data": [4, 16, 32, 3]
+                    }
+                    self.dynamic_shape.opt_input_shape = {
+                        "input_data": [1, 16, 32, 3]
+                    }
+                else:
+                    raise AssertionError()

         def clear_dynamic_shape():
             self.dynamic_shape.max_input_shape = {}
@@ -203,12 +198,7 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
         ]

         def generate_trt_nodes_num(attrs, dynamic_shape):
-            if (
-                not dynamic_shape
-                and self.dim1 == 0
-                and self.dim2 == 0
-                and self.dim3 == 0
-            ):
+            if not dynamic_shape and (self.dims == 1 or self.dims == 0):
                 return 0, 3
             return 1, 2
@@ -234,23 +224,7 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
                 attrs, True
             ), (1e-3, 1e-3)

-    def add_skip_trt_case(self):
-        ver = paddle_infer.get_trt_compile_version()
-        if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000:
-
-            def teller(program_config, predictor_config):
-                if not predictor_config.tensorrt_dynamic_shape_enabled():
-                    return True
-                return False
-
-            self.add_skip_case(
-                teller,
-                SkipReasons.TRT_NOT_IMPLEMENTED,
-                "Need to repair the case: the output of GPU and tensorrt has diff in trt6, the prelu static plugin has bug.",
-            )
-
     def test(self):
-        self.add_skip_trt_case()
         self.run_test()
......