# test_xpu_gather_squeeze_pass.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from functools import partial

import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig


class TestGatherAddTransposePass(PassAutoScanTest):
    """Auto-scan test for the XPU ``gather_squeeze_pass``.

    Each sampled program feeds one rank-3 float32 tensor into two ``gather``
    ops (same axis, separate single-element int64 indices); each gather
    output is then passed through a ``squeeze2`` on that same axis.

    NOTE(review): the class name mentions "Add"/"Transpose", but what this
    test exercises is ``gather_squeeze_pass``. The name is left unchanged so
    existing references (e.g. test selectors) keep working; consider
    renaming it to ``TestGatherSqueezePass``.
    """

    def sample_predictor_configs(self, program_config):
        """Yield an XPU inference config together with the expected ops.

        Args:
            program_config: the sampled program (not used here; the
                expected op list is identical for every sample).

        Yields:
            ``(config, op_types, tolerances)`` where ``op_types`` is the op
            sequence this test expects after ``gather_squeeze_pass`` runs:
            a ``transpose2`` ahead of each ``gather``, then both ``squeeze2``
            ops. ``tolerances`` is presumably ``(atol, rtol)`` -- TODO confirm
            against ``auto_scan_test``.
        """
        config = self.create_inference_config(use_xpu=True)
        expected_ops = [
            "transpose2",
            "gather",
            "transpose2",
            "gather",
            "squeeze2",
            "squeeze2",
        ]
        yield config, expected_ops, (1e-3, 1e-3)

    def sample_program_config(self, draw):
        """Draw a program made of two gather -> squeeze2 chains on one input.

        Args:
            draw: hypothesis ``draw`` callable used to sample the input shape.

        Returns:
            A ``ProgramConfig`` whose ops are ordered
            ``[gather0, gather1, squeeze0, squeeze1]`` and whose outputs are
            ``squeeze_out0`` and ``squeeze_out1``.
        """
        # Rank-3 input with every dim in [1, 4]; index 0 is therefore always
        # in range along the gathered axis.
        x_shape = draw(
            st.lists(
                st.integers(min_value=1, max_value=4), min_size=3, max_size=3
            )
        )

        def generate_data(shape):
            return np.random.random(shape).astype(np.float32)

        def generate_index():
            # Single int64 index selecting position 0 along the gather axis,
            # so the gathered axis should end up with size 1 for the squeeze.
            return np.array([0], dtype=np.int64)

        axis = 2
        # Squeeze the same axis the gather operates on; deriving it from
        # ``axis`` keeps the two in sync if the axis is ever changed.
        axes = [axis]
        num_branches = 2

        gather_ops = [
            OpConfig(
                "gather",
                inputs={"X": ["gather_in"], "Index": [f"gather_index{i}"]},
                outputs={"Out": [f"gather_out{i}"]},
                axis=axis,
            )
            for i in range(num_branches)
        ]
        squeeze_ops = [
            OpConfig(
                "squeeze2",
                inputs={"X": [f"gather_out{i}"]},
                outputs={"Out": [f"squeeze_out{i}"]},
                axes=axes,
            )
            for i in range(num_branches)
        ]
        # Keep both gathers ahead of both squeezes: the op list expected in
        # sample_predictor_configs depends on this order.
        ops = gather_ops + squeeze_ops

        inputs = {
            "gather_in": TensorConfig(
                data_gen=partial(generate_data, x_shape)
            ),
        }
        for i in range(num_branches):
            # generate_index takes no arguments, so no partial() is needed.
            inputs[f"gather_index{i}"] = TensorConfig(data_gen=generate_index)

        return ProgramConfig(
            ops=ops,
            inputs=inputs,
            weights={},
            outputs=[f"squeeze_out{i}" for i in range(num_branches)],
        )

    def test(self):
        """Run the auto-scan harness against ``gather_squeeze_pass``."""
        self.run_and_statis(
            quant=False, max_examples=25, passes=["gather_squeeze_pass"]
        )


if __name__ == "__main__":
    # Run the tests in this module when it is executed directly as a script.
    unittest.main()