# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from functools import partial
from typing import List

import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import SkipReasons, TrtLayerAutoScanTest

import paddle.inference as paddle_infer


class TrtConvertGatherTest(TrtLayerAutoScanTest):
    """Auto-scan test for the Paddle ``gather`` op -> TensorRT converter.

    Every sampled program holds a single ``gather`` op. The scan varies the
    rank of ``X`` (1-4), the index values, the ``axis`` and ``overwrite``
    attributes, whether an extra ``Axis`` tensor input is fed, and the index
    dtype (int32 / int64).
    """

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        """Return True only if the op's ``axis`` attribute is < rank of X.

        Programs whose ``axis`` does not address a dimension of
        ``input_data`` are rejected instead of being run.
        """
        attrs = [op.attrs for op in program_config.ops]
        rank = len(program_config.inputs['input_data'].shape)
        return attrs[0]['axis'] < rank

    def sample_program_configs(self):
        """Yield one single-``gather`` ProgramConfig per parameter combination.

        Side effect: right before each yield, ``self.shape``, ``self.axis``,
        ``self.input_num`` and ``self.index_type_int32`` are set to describe
        the program being yielded. ``sample_predictor_configs`` reads
        ``self.shape`` and ``self.input_num`` from these attributes.
        """

        def generate_input1(shape):
            # Data tensor X: random float32 values.
            return np.random.random(shape).astype(np.float32)

        def generate_input2(index):
            # Index tensor with int32 dtype.
            return np.array(index).astype(np.int32)

        def generate_input4(index):
            # Index tensor with int64 dtype.
            return np.array(index).astype(np.int64)

        def generate_input3(axis):
            # Optional ``Axis`` tensor: a 1-element int32 array.
            return np.array([axis]).astype(np.int32)

        for shape in [[32], [16, 64], [32, 16, 16], [32, 64, 16, 32]]:
            for index in [[1, 4], [4, 8]]:
                for axis in [0, 1, 2, 3]:
                    for overwrite in [True, False]:
                        for op_inputs in [
                            {"X": ["input_data"], "Index": ["index_data"]},
                            {
                                "X": ["input_data"],
                                "Index": ["index_data"],
                                "Axis": ["axis_data"],
                            },
                        ]:
                            for index_type_int32 in [True, False]:
                                self.shape = shape
                                self.axis = axis
                                self.input_num = len(op_inputs)
                                self.index_type_int32 = index_type_int32
                                has_axis_input = self.input_num == 3

                                ops_config = [
                                    {
                                        "op_type": "gather",
                                        "op_inputs": op_inputs,
                                        "op_outputs": {"Out": ["output_data"]},
                                        "op_attrs": {
                                            "overwrite": overwrite,
                                            "axis": axis,
                                        },
                                    }
                                ]
                                ops = self.generate_op_config(ops_config)

                                # NOTE(review): when the Axis tensor is fed,
                                # the index is always int32, so the
                                # index_type_int32=False pass duplicates the
                                # True pass. Confirm whether int64 indices
                                # should be exercised in that case too.
                                index_gen = (
                                    generate_input2
                                    if index_type_int32 or has_axis_input
                                    else generate_input4
                                )
                                inputs = {
                                    "input_data": TensorConfig(
                                        data_gen=partial(generate_input1, shape)
                                    ),
                                    "index_data": TensorConfig(
                                        data_gen=partial(index_gen, index)
                                    ),
                                }
                                if has_axis_input:
                                    inputs["axis_data"] = TensorConfig(
                                        data_gen=partial(generate_input3, axis)
                                    )

                                program_config = ProgramConfig(
                                    ops=ops,
                                    weights={},
                                    inputs=inputs,
                                    outputs=["output_data"],
                                )

                                yield program_config

    def sample_predictor_configs(
        self, program_config
    ) -> (paddle_infer.Config, List[int], float):
        """Yield (inference config, expected node counts, tolerance) tuples.

        The configurations are yielded in this order:

        * static shape, Float32 precision, tolerance 1e-5;
        * static shape, Half precision, tolerance 1e-3;
        * dynamic shape, Float32 precision, tolerance 1e-5;
        * dynamic shape, Half precision, tolerance 1e-3.

        It reads ``self.shape`` and ``self.input_num``, which
        ``sample_program_configs`` sets for the program being checked.
        """
        # Rank of input_data -> (min, max, opt) shape for dynamic-shape mode.
        input_data_ranges = {
            1: ([4], [128], [16]),
            2: ([2, 4], [256, 256], [64, 32]),
            3: ([2, 4, 4], [128, 256, 256], [16, 64, 32]),
            4: ([2, 4, 4, 2], [128, 256, 64, 128], [16, 64, 16, 32]),
        }

        def generate_dynamic_shape(attrs):
            # ``attrs`` is accepted for signature compatibility but unused.
            ranges = input_data_ranges.get(len(self.shape))
            if ranges is None:
                # Unlisted rank: leave the dynamic-shape ranges unchanged.
                return
            min_shape, max_shape, opt_shape = ranges
            # The sampled index lists all have length 2 (1-D index tensor).
            self.dynamic_shape.min_input_shape = {
                "input_data": list(min_shape),
                "index_data": [1],
            }
            self.dynamic_shape.max_input_shape = {
                "input_data": list(max_shape),
                "index_data": [4],
            }
            self.dynamic_shape.opt_input_shape = {
                "input_data": list(opt_shape),
                "index_data": [2],
            }

        def clear_dynamic_shape():
            self.dynamic_shape.max_input_shape = {}
            self.dynamic_shape.min_input_shape = {}
            self.dynamic_shape.opt_input_shape = {}

        def generate_trt_nodes_num(dynamic_shape):
            # Expected node counts for the current program. Presumably the
            # tuple is (TRT engine count, remaining Paddle op count) per the
            # TrtLayerAutoScanTest contract -- TODO confirm. Programs with
            # the Axis tensor input, and all static-shape programs, are
            # expected not to be offloaded to TensorRT.
            if self.input_num == 3:
                return 0, 5
            if dynamic_shape:
                return 1, 3
            return 0, 4

        attrs = [op.attrs for op in program_config.ops]

        # for static_shape
        clear_dynamic_shape()
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(
            False
        ), 1e-5
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(
            False
        ), 1e-3

        # for dynamic_shape
        generate_dynamic_shape(attrs)
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(True), 1e-5
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(True), 1e-3

    def add_skip_trt_case(self):
        """Register skips for cases known to fail on TensorRT older than 7.0."""
        ver = paddle_infer.get_trt_compile_version()
        # Encode (major, minor, patch) as major*1000 + minor*100 + patch*10.
        # The original reused ver[0] for the last term instead of the patch
        # level ver[2].
        if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 7000:

            def teller1(program_config, predictor_config):
                # Skip only in dynamic-shape mode, i.e. when min shapes are set.
                if len(self.dynamic_shape.min_input_shape) != 0:
                    inputs = program_config.inputs
                    if (
                        len(inputs['input_data'].shape) == 1
                        or len(inputs['index_data'].shape) == 1
                    ):
                        return True
                return False

            self.add_skip_case(
                teller1,
                SkipReasons.TRT_NOT_SUPPORT,
                "Need to repair the case: trt reshape out failed for dynamic shape mode when inputs' dims==1. under trt7.0 ",
            )

    def test(self):
        self.add_skip_trt_case()
        self.run_test()


if __name__ == "__main__":
    # When executed directly, run the tests defined in this module via unittest.
    unittest.main()