未验证 提交 32dae48a 编写于 作者: Z zhoutianzi666 提交者: GitHub

add unitest for reshpe 0 dims (#53685)

上级 4a97ba5d
...@@ -431,5 +431,99 @@ class TrtConvertReshapeTest3(TrtLayerAutoScanTest): ...@@ -431,5 +431,99 @@ class TrtConvertReshapeTest3(TrtLayerAutoScanTest):
self.run_test() self.run_test()
class TrtConvertReshapeZeroDimsTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]]):
if self.dims > 0:
self.input_shape = [1] * self.dims
return np.random.random(self.input_shape).astype(np.float32)
elif self.dims == 0:
self.input_shape = []
return np.random.random([]).astype(np.float32)
for dims in [0, 1, 2, 3]:
for shape in [
[],
[1, 1],
]:
dics = [
{
"shape": shape,
},
]
self.dims = dims
dics_intput = [{"X": ["reshape_input"]}]
ops_config = [
{
"op_type": "reshape",
"op_inputs": dics_intput[0],
"op_outputs": {"Out": ["reshape_out"]},
"op_attrs": dics[0],
}
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"reshape_input": TensorConfig(
data_gen=partial(generate_input1, dics)
)
},
outputs=["reshape_out"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"reshape_input": self.input_shape
}
self.dynamic_shape.max_input_shape = {
"reshape_input": self.input_shape
}
self.dynamic_shape.opt_input_shape = {
"reshape_input": self.input_shape
}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
# only test dynamic shape mode
return 1, 2
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-3
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册