未验证 提交 56f108ff 编写于 作者: W wangxinxin08 提交者: GitHub

filter unsupported inputs for elementwise op in op teller (#41253)

* filter unsupported inputs for elementwise op in op teller

* add unittest for corner case
上级 5d3fd4fe
......@@ -1011,6 +1011,21 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
VLOG(3) << "Now trt may not support two 1d tensor elementwise op.";
return false;
}
if (op_type == "elementwise_add" || op_type == "elementwise_mul") {
if (x_var_desc->Persistable()) {
VLOG(3) << "Input X is a parameter which is not supported for "
"elementwise_add/elementwise_mul in tensorrt, swap x and "
"y will work";
return false;
}
}
if (op_type == "elementwise_sub" || op_type == "elementwise_div") {
if (x_var_desc->Persistable() || y_var_desc->Persistable()) {
VLOG(3) << "Input X or Input Y is a parameter which is not supported "
"for elementwise_sub/elementwise_div in tensorrt";
return false;
}
}
}
if (op_type == "stack") {
......
......@@ -397,5 +397,139 @@ class TrtConvertElementwiseTest_two_input_with_broadcast(TrtLayerAutoScanTest):
self.run_test()
class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input(shape):
return np.random.random(shape).astype(np.float32)
def generate_weight():
return np.random.randn(32).astype(np.float32)
for batch in [1, 2, 4]:
for shape in [[32], [batch, 32], [batch, 32, 32],
[batch, 32, 16, 32]]:
for op_type in [
"elementwise_add", "elementwise_mul", "elementwise_sub",
"elementwise_div"
]:
for axis in [-1 if len(shape) == 1 else 1]:
self.dims = len(shape)
dics = [{"axis": axis}]
ops_config = [{
"op_type": op_type,
"op_inputs": {
"X": ["weight"],
"Y": ["input_data"]
},
"op_outputs": {
"Out": ["output_data"]
},
"op_attrs": dics[0]
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"weight":
TensorConfig(data_gen=partial(generate_weight))
},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input, shape)),
},
outputs=["output_data"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
# The input.dims[1] must be equal to the weight's length.
if self.dims == 1:
self.dynamic_shape.min_input_shape = {"input_data": [4]}
self.dynamic_shape.max_input_shape = {"input_data": [256]}
self.dynamic_shape.opt_input_shape = {"input_data": [16]}
elif self.dims == 2:
self.dynamic_shape.min_input_shape = {"input_data": [1, 32]}
self.dynamic_shape.max_input_shape = {"input_data": [4, 32]}
self.dynamic_shape.opt_input_shape = {"input_data": [2, 32]}
elif self.dims == 3:
self.dynamic_shape.min_input_shape = {"input_data": [1, 32, 4]}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 256]
}
self.dynamic_shape.opt_input_shape = {"input_data": [2, 32, 16]}
elif self.dims == 4:
self.dynamic_shape.min_input_shape = {
"input_data": [1, 32, 4, 4]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 128, 256]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 32, 32, 16]
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
if self.dims == 1:
return 0, 3
return 1, 2
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
input_x_names = program_config.ops[0].inputs["X"]
for weight_name in program_config.weights:
if weight_name in input_x_names:
return True
op_type = program_config.ops[0].type
if op_type in ["elementwise_sub", "elementwise_div"]:
input_y_names = program_config.ops[0].inputs["Y"]
for weight_name in program_config.weights:
if weight_name in input_y_names:
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_SUPPORT,
"Input X should not be parameters in elementwise op and Input Y should not be parameters in elementwise_sub or elementwise_div op"
)
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册