Unverified commit 82c73884, authored by Yuanle Liu and committed by GitHub

[inference Zero-Dim]prelu trt converter support zero dim tensor (#53634)

* prelu op trt converter support zero dim
Parent 5417382d
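For reference, prelu keeps positive inputs unchanged and scales negative inputs by alpha; what this commit adds is that the TensorRT converter now also accepts a 0-D (rank-0) input and a 0-D alpha. A minimal numpy sketch of the computation for that new case (numpy only stands in for the expected engine output here; nothing below is Paddle or TensorRT API):

```python
import numpy as np

def prelu_reference(x: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    # prelu: keep positive values, scale negative values by alpha.
    return np.where(x > 0, x, alpha * x)

# The newly supported case: 0-D (rank-0) input and 0-D alpha.
x = np.asarray(-2.0, dtype=np.float32)      # shape ()
alpha = np.asarray(0.25, dtype=np.float32)  # shape ()
out = prelu_reference(x, alpha)
print(out, out.shape)                       # -0.5 ()
```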
......@@ -356,6 +356,8 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const {
}
};
InsertTransposeOp();
AddStatis(transposed_ops.size());
}
} // namespace ir
......
......@@ -87,7 +87,6 @@ class PReluOpConverter : public OpConverter {
if (hw_tensor != nullptr) {
shape_tensor = Concat(
std::vector<nvinfer1::ITensor*>{n_tensor, c_tensor, hw_tensor});
} else {
shape_tensor =
Concat(std::vector<nvinfer1::ITensor*>{n_tensor, c_tensor});
......
......@@ -1837,28 +1837,28 @@ struct SimpleOpTypeSetTeller : public Teller {
"the pass.";
return false;
}
auto* var_desc = block->FindVar(desc.Input("Alpha")[0]);
if (!var_desc) {
auto* alpha_var = block->FindVar(desc.Input("Alpha")[0]);
if (!alpha_var) {
VLOG(3) << "Variable Alpha of prelu TRT converter not found.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
if (!with_dynamic_shape && x_shape.size() == 1) {
VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt "
"with static shape.";
auto alpha_shape = alpha_var->GetShape();
if (!with_dynamic_shape && alpha_shape.size() == 0) {
VLOG(3) << op_type
<< " op does not support alpha's dim is 0 in tensorrt "
"static shape mode.";
return false;
}
#if IS_TRT_VERSION_LT(7000)
if (!with_dynamic_shape) {
// TODO(inference): fix trt6 static plugin error.
VLOG(3) << "prelu static plugin in trt6 has bug.";
auto x_var_name = desc.Input("X")[0];
auto* x_var = block->FindVar(x_var_name);
const auto x_shape = x_var->GetShape();
if (!with_dynamic_shape && (x_shape.size() == 1 || x_shape.size() == 0)) {
VLOG(3) << op_type
<< " op does not support input's dim is 1 or 0 in tensorrt "
"with static shape.";
return false;
}
#endif
}
if (op_type == "mish") {
......
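The op-teller change above boils down to a small decision rule: with static shape, prelu is rejected when alpha is 0-D or when the input is 0-D or 1-D; with dynamic shape both cases are allowed. A hedged Python paraphrase of that rule (illustration only, not the actual teller interface):

```python
def prelu_supported(x_rank: int, alpha_rank: int, with_dynamic_shape: bool) -> bool:
    """Paraphrase of the prelu checks added to the op teller in this diff."""
    if not with_dynamic_shape and alpha_rank == 0:
        return False  # static-shape mode cannot carry a 0-D alpha
    if not with_dynamic_shape and x_rank in (0, 1):
        return False  # static-shape mode cannot carry a 0-D or 1-D input
    return True

# 0-D input/alpha only go through TensorRT under dynamic shape.
assert prelu_supported(x_rank=0, alpha_rank=0, with_dynamic_shape=True)
assert not prelu_supported(x_rank=0, alpha_rank=0, with_dynamic_shape=False)
```

This matches the expectation encoded in the test further down: generate_trt_nodes_num returns zero TensorRT ops when dynamic shape is off and self.dims is 0 or 1.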
......@@ -2340,7 +2340,7 @@ void PReluInferMeta(const MetaTensor& x,
1,
phi::errors::InvalidArgument(
"For mode 'element', rank of input X must be "
"equal or larger than 2. But recevied X's "
"equal or larger than 1. But recevied X's "
"rank: %d",
x_rank));
PADDLE_ENFORCE_EQ(
......
......@@ -18,7 +18,7 @@ from typing import Any, Dict, List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import SkipReasons, TrtLayerAutoScanTest
from trt_layer_auto_scan_test import TrtLayerAutoScanTest
import paddle.inference as paddle_infer
......@@ -28,170 +28,165 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
return True
def sample_program_configs(self):
def generate_input(batch, dim1, dim2, dim3):
shape = [batch]
if dim1 != 0:
shape.append(dim1)
if dim2 != 0:
shape.append(dim2)
if dim3 != 0:
shape.append(dim3)
return np.random.random(shape).astype(np.float32)
def generate_alpha(attrs: List[Dict[str, Any]], dim1, dim2, dim3):
def generate_input(attrs: List[Dict[str, Any]], batch):
if self.dims == 0:
return np.random.random([]).astype(np.float32)
elif self.dims == 1:
return np.random.random([16]).astype(np.float32)
elif self.dims == 2:
return np.random.random([1, 3]).astype(np.float32)
elif self.dims == 3:
if attrs[0]["data_format"] == "NCHW":
return np.random.random([batch, 3, 16]).astype(np.float32)
elif attrs[0]["data_format"] == "NHWC":
return np.random.random([batch, 16, 3]).astype(np.float32)
else:
raise AssertionError()
else:
if attrs[0]["data_format"] == "NCHW":
return np.random.random([batch, 3, 16, 32]).astype(
np.float32
)
else:
return np.random.random([batch, 16, 32, 3]).astype(
np.float32
)
def generate_alpha(attrs: List[Dict[str, Any]]):
if self.dims == 0:
return np.random.random([]).astype(np.float32)
if attrs[0]["mode"] == "all":
return np.random.random(size=(1)).astype(np.float32)
elif (
attrs[0]["mode"] == "channel"
and attrs[0]["data_format"] == "NCHW"
):
shape = [1]
if dim1 != 0:
shape.append(dim1)
if dim2 != 0:
shape.append(dim2)
if dim3 != 0:
shape.append(dim3)
return np.random.random(size=shape[1]).astype(np.float32)
elif (
attrs[0]["mode"] == "channel"
and attrs[0]["data_format"] == "NHWC"
):
shape = [1]
if dim1 != 0:
shape.append(dim1)
if dim2 != 0:
shape.append(dim2)
if dim3 != 0:
shape.append(dim3)
return np.random.random(size=shape[-1]).astype(np.float32)
return np.random.random([1]).astype(np.float32)
elif attrs[0]["mode"] == "channel":
return np.random.random([3]).astype(np.float32)
elif attrs[0]["mode"] == "element":
shape = [1]
if dim1 != 0:
shape.append(dim1)
if dim2 != 0:
shape.append(dim2)
if dim3 != 0:
shape.append(dim3)
return np.random.random(size=shape).astype(np.float32)
if self.dims == 1:
return np.random.random([16]).astype(np.float32)
elif self.dims == 2:
return np.random.random([1, 3]).astype(np.float32)
elif self.dims == 3:
if attrs[0]["data_format"] == "NCHW":
return np.random.random([1, 3, 16]).astype(np.float32)
elif attrs[0]["data_format"] == "NHWC":
return np.random.random([1, 16, 3]).astype(np.float32)
else:
raise AssertionError()
else:
if attrs[0]["data_format"] == "NCHW":
return np.random.random([1, 3, 16, 32]).astype(
np.float32
)
elif attrs[0]["data_format"] == "NHWC":
return np.random.random([1, 16, 32, 3]).astype(
np.float32
)
else:
raise AssertionError()
for batch in [1, 4]:
for dim1 in [0, 3]:
for dim2 in [0, 16]:
for dim3 in [0, 32]:
self.dim1 = dim1
self.dim2 = dim2
self.dim3 = dim3
if dim1 == 0 and dim2 != 0:
for dims in [0, 1, 2, 3, 4]:
for mode in ["all", "element", "channel"]:
for data_format in ["NCHW", "NHWC"]:
if (mode == "element" or mode == "all") and dims == 0:
continue
if dim1 == 0 and dim2 == 0 and dim3 != 0:
if mode == "channel" and dims != 4:
continue
for mode in ["all", "channel", "element"]:
for data_format in ['NCHW', 'NHWC']:
if (
mode == "channel"
and dim1 == 0
and data_format == "NCHW"
):
continue
if (
mode == "channel"
and dim3 == 0
and data_format == "NHWC"
):
continue
dics = [
{"mode": mode, "data_format": data_format}
]
ops_config = [
{
"op_type": "prelu",
"op_inputs": {
"X": ["input_data"],
"Alpha": ["alpha_weight"],
},
"op_outputs": {"Out": ["output_data"]},
"op_attrs": dics[0],
}
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"alpha_weight": TensorConfig(
data_gen=partial(
generate_alpha,
dics,
dim1,
dim2,
dim3,
)
)
},
inputs={
"input_data": TensorConfig(
data_gen=partial(
generate_input,
batch,
dim1,
dim2,
dim3,
)
),
},
outputs=["output_data"],
self.dims = dims
dics = [{"mode": mode, "data_format": data_format}]
ops_config = [
{
"op_type": "prelu",
"op_inputs": {
"X": ["input_data"],
"Alpha": ["alpha_weight"],
},
"op_outputs": {"Out": ["output_data"]},
"op_attrs": dics[0],
}
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"alpha_weight": TensorConfig(
data_gen=partial(generate_alpha, dics)
)
yield program_config
},
inputs={
"input_data": TensorConfig(
data_gen=partial(
generate_input, dics, batch
)
),
},
outputs=["output_data"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
if self.dim1 == 0:
self.dynamic_shape.min_input_shape = {
"input_data": [1],
}
self.dynamic_shape.max_input_shape = {
"input_data": [4],
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2],
}
else:
if self.dim2 == 0 and self.dim3 == 0:
if self.dims == 0:
self.dynamic_shape.min_input_shape = {"input_data": []}
self.dynamic_shape.max_input_shape = {"input_data": []}
self.dynamic_shape.opt_input_shape = {"input_data": []}
elif self.dims == 1:
self.dynamic_shape.min_input_shape = {"input_data": [16]}
self.dynamic_shape.max_input_shape = {"input_data": [16]}
self.dynamic_shape.opt_input_shape = {"input_data": [16]}
elif self.dims == 2:
self.dynamic_shape.min_input_shape = {"input_data": [1, 3]}
self.dynamic_shape.max_input_shape = {"input_data": [1, 3]}
self.dynamic_shape.opt_input_shape = {"input_data": [1, 3]}
elif self.dims == 3:
if attrs[0]["data_format"] == "NCHW":
self.dynamic_shape.min_input_shape = {
"input_data": [1, 1],
"input_data": [1, 3, 16]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 32],
"input_data": [4, 3, 16]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 3],
"input_data": [1, 3, 16]
}
elif self.dim2 != 0 and self.dim3 != 0:
elif attrs[0]["data_format"] == "NHWC":
self.dynamic_shape.min_input_shape = {
"input_data": [1, 1, 1, 1],
"input_data": [1, 16, 3]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 3, 16, 32],
"input_data": [4, 16, 3]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 3, 16, 32],
"input_data": [1, 16, 3]
}
elif self.dim3 == 0:
else:
raise AssertionError()
else:
if attrs[0]["data_format"] == "NCHW":
self.dynamic_shape.min_input_shape = {
"input_data": [1, 1, 1],
"input_data": [1, 3, 16, 32]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 3, 32],
"input_data": [4, 3, 16, 32]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 3, 16],
"input_data": [1, 3, 16, 32]
}
elif attrs[0]["data_format"] == "NHWC":
self.dynamic_shape.min_input_shape = {
"input_data": [1, 16, 32, 3]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 16, 32, 3]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [1, 16, 32, 3]
}
else:
raise AssertionError()
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
......@@ -203,12 +198,7 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
]
def generate_trt_nodes_num(attrs, dynamic_shape):
if (
not dynamic_shape
and self.dim1 == 0
and self.dim2 == 0
and self.dim3 == 0
):
if not dynamic_shape and (self.dims == 1 or self.dims == 0):
return 0, 3
return 1, 2
......@@ -234,23 +224,7 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
attrs, True
), (1e-3, 1e-3)
def add_skip_trt_case(self):
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000:
def teller(program_config, predictor_config):
if not predictor_config.tensorrt_dynamic_shape_enabled():
return True
return False
self.add_skip_case(
teller,
SkipReasons.TRT_NOT_IMPLEMENTED,
"Need to repair the case: the output of GPU and tensorrt has diff in trt6, the prelu static plugin has bug.",
)
def test(self):
self.add_skip_trt_case()
self.run_test()
......
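The regenerated test sweeps dims over {0, 1, 2, 3, 4}, every mode, and both data formats, and generate_alpha follows the usual prelu convention: shape [] for the 0-D case, [1] for mode "all", [C] for mode "channel", and the full per-sample shape for mode "element". A hedged numpy sketch of how those alpha shapes broadcast against an NCHW input (illustrative only; it mirrors generate_alpha above rather than calling any Paddle API):

```python
import numpy as np

def prelu_ref(x, alpha):
    # Reference prelu; alpha broadcasts against x.
    return np.where(x > 0, x, alpha * x)

x = np.random.randn(4, 3, 16, 32).astype(np.float32)  # NCHW input, dims == 4

# mode "all": a single weight.
alpha_all = np.random.rand(1).astype(np.float32)
# mode "channel": one weight per channel, reshaped so numpy broadcasts it over N, H, W.
alpha_channel = np.random.rand(3).astype(np.float32).reshape(1, 3, 1, 1)
# mode "element": one weight per element of a sample.
alpha_element = np.random.rand(1, 3, 16, 32).astype(np.float32)

for alpha in (alpha_all, alpha_channel, alpha_element):
    assert prelu_ref(x, alpha).shape == x.shape
```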